From 2ecba4f2a036fed41f85d69c8beea42470bcf07a Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Wed, 1 Mar 2023 22:59:54 +0100 Subject: [PATCH 01/95] first commit for the addition of the TabDDPM plugin --- .../core/models/tabular_ddpm/.lib/__init__.py | 12 + .../core/models/tabular_ddpm/.lib/data.py | 718 +++++++++++++ .../core/models/tabular_ddpm/.lib/deep.py | 168 +++ .../core/models/tabular_ddpm/.lib/env.py | 39 + .../core/models/tabular_ddpm/.lib/metrics.py | 158 +++ .../core/models/tabular_ddpm/.lib/util.py | 433 ++++++++ .../core/models/tabular_ddpm/.pipeline.py | 80 ++ .../core/models/tabular_ddpm/.sample.py | 159 +++ .../core/models/tabular_ddpm/.train.py | 156 +++ .../plugins/core/models/tabular_ddpm/.tune.py | 127 +++ .../core/models/tabular_ddpm/.utils_train.py | 88 ++ .../core/models/tabular_ddpm/README.md | 3 + .../core/models/tabular_ddpm/__init__.py | 2 + .../gaussian_multinomial_diffsuion.py | 992 ++++++++++++++++++ .../core/models/tabular_ddpm/modules.py | 486 +++++++++ .../core/models/tabular_ddpm/requirements.txt | 15 + .../plugins/core/models/tabular_ddpm/utils.py | 174 +++ src/synthcity/plugins/generic/plugin_ddpm.py | 217 ++++ third-party/tab-ddpm | 1 + 19 files changed, 4028 insertions(+) create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/util.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.pipeline.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.sample.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.train.py create mode 100644 
src/synthcity/plugins/core/models/tabular_ddpm/.tune.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.utils_train.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/README.md create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/__init__.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/modules.py create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/utils.py create mode 100644 src/synthcity/plugins/generic/plugin_ddpm.py create mode 160000 third-party/tab-ddpm diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py new file mode 100644 index 00000000..54d6f6bb --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py @@ -0,0 +1,12 @@ +import torch +from icecream import install + +torch.set_num_threads(1) +install() + +from . 
import env # noqa +from .data import * # noqa +from .deep import * # noqa +from .env import * # noqa +from .metrics import * # noqa +from .util import * # noqa diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py new file mode 100644 index 00000000..912ce259 --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py @@ -0,0 +1,718 @@ +import hashlib +from collections import Counter +from copy import deepcopy +from dataclasses import astuple, dataclass, replace +from importlib.resources import path +from pathlib import Path +from typing import Any, Literal, Optional, Union, cast, Tuple, Dict, List + +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.pipeline import make_pipeline +import sklearn.preprocessing +import torch +import os +from category_encoders import LeaveOneOutEncoder +from sklearn.impute import SimpleImputer +from sklearn.preprocessing import StandardScaler +from scipy.spatial.distance import cdist + +from . 
import env, util +from .metrics import calculate_metrics as calculate_metrics_ +from .util import TaskType, load_json + +ArrayDict = Dict[str, np.ndarray] +TensorDict = Dict[str, torch.Tensor] + + +CAT_MISSING_VALUE = '__nan__' +CAT_RARE_VALUE = '__rare__' +Normalization = Literal['standard', 'quantile', 'minmax'] +NumNanPolicy = Literal['drop-rows', 'mean'] +CatNanPolicy = Literal['most_frequent'] +CatEncoding = Literal['one-hot', 'counter'] +YPolicy = Literal['default'] + + +class StandardScaler1d(StandardScaler): + def partial_fit(self, X, *args, **kwargs): + assert X.ndim == 1 + return super().partial_fit(X[:, None], *args, **kwargs) + + def transform(self, X, *args, **kwargs): + assert X.ndim == 1 + return super().transform(X[:, None], *args, **kwargs).squeeze(1) + + def inverse_transform(self, X, *args, **kwargs): + assert X.ndim == 1 + return super().inverse_transform(X[:, None], *args, **kwargs).squeeze(1) + + +def get_category_sizes(X: Union[torch.Tensor, np.ndarray]) -> List[int]: + XT = X.T.cpu().tolist() if isinstance(X, torch.Tensor) else X.T.tolist() + return [len(set(x)) for x in XT] + + +@dataclass(frozen=False) +class Dataset: + X_num: Optional[ArrayDict] + X_cat: Optional[ArrayDict] + y: ArrayDict + y_info: Dict[str, Any] + task_type: TaskType + n_classes: Optional[int] + + @classmethod + def from_dir(cls, dir_: Union[Path, str]) -> 'Dataset': + dir_ = Path(dir_) + splits = [k for k in ['train', 'val', 'test'] if dir_.joinpath(f'y_{k}.npy').exists()] + + def load(item) -> ArrayDict: + return { + x: cast(np.ndarray, np.load(dir_ / f'{item}_{x}.npy', allow_pickle=True)) # type: ignore[code] + for x in splits + } + + if Path(dir_ / 'info.json').exists(): + info = util.load_json(dir_ / 'info.json') + else: + info = None + return Dataset( + load('X_num') if dir_.joinpath('X_num_train.npy').exists() else None, + load('X_cat') if dir_.joinpath('X_cat_train.npy').exists() else None, + load('y'), + {}, + TaskType(info['task_type']), + 
info.get('n_classes'), + ) + + @property + def is_binclass(self) -> bool: + return self.task_type == TaskType.BINCLASS + + @property + def is_multiclass(self) -> bool: + return self.task_type == TaskType.MULTICLASS + + @property + def is_regression(self) -> bool: + return self.task_type == TaskType.REGRESSION + + @property + def n_num_features(self) -> int: + return 0 if self.X_num is None else self.X_num['train'].shape[1] + + @property + def n_cat_features(self) -> int: + return 0 if self.X_cat is None else self.X_cat['train'].shape[1] + + @property + def n_features(self) -> int: + return self.n_num_features + self.n_cat_features + + def size(self, part: Optional[str]) -> int: + return sum(map(len, self.y.values())) if part is None else len(self.y[part]) + + @property + def nn_output_dim(self) -> int: + if self.is_multiclass: + assert self.n_classes is not None + return self.n_classes + else: + return 1 + + def get_category_sizes(self, part: str) -> List[int]: + return [] if self.X_cat is None else get_category_sizes(self.X_cat[part]) + + def calculate_metrics( + self, + predictions: Dict[str, np.ndarray], + prediction_type: Optional[str], + ) -> Dict[str, Any]: + metrics = { + x: calculate_metrics_( + self.y[x], predictions[x], self.task_type, prediction_type, self.y_info + ) + for x in predictions + } + if self.task_type == TaskType.REGRESSION: + score_key = 'rmse' + score_sign = -1 + else: + score_key = 'accuracy' + score_sign = 1 + for part_metrics in metrics.values(): + part_metrics['score'] = score_sign * part_metrics[score_key] + return metrics + +def change_val(dataset: Dataset, val_size: float = 0.2): + # should be done before transformations + + y = np.concatenate([dataset.y['train'], dataset.y['val']], axis=0) + + ixs = np.arange(y.shape[0]) + if dataset.is_regression: + train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777) + else: + train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777, stratify=y) + 
+ dataset.y['train'] = y[train_ixs] + dataset.y['val'] = y[val_ixs] + + if dataset.X_num is not None: + X_num = np.concatenate([dataset.X_num['train'], dataset.X_num['val']], axis=0) + dataset.X_num['train'] = X_num[train_ixs] + dataset.X_num['val'] = X_num[val_ixs] + + if dataset.X_cat is not None: + X_cat = np.concatenate([dataset.X_cat['train'], dataset.X_cat['val']], axis=0) + dataset.X_cat['train'] = X_cat[train_ixs] + dataset.X_cat['val'] = X_cat[val_ixs] + + return dataset + +def num_process_nans(dataset: Dataset, policy: Optional[NumNanPolicy]) -> Dataset: + assert dataset.X_num is not None + nan_masks = {k: np.isnan(v) for k, v in dataset.X_num.items()} + if not any(x.any() for x in nan_masks.values()): # type: ignore[code] + assert policy is None + return dataset + + assert policy is not None + if policy == 'drop-rows': + valid_masks = {k: ~v.any(1) for k, v in nan_masks.items()} + assert valid_masks[ + 'test' + ].all(), 'Cannot drop test rows, since this will affect the final metrics.' 
+ new_data = {} + for data_name in ['X_num', 'X_cat', 'y']: + data_dict = getattr(dataset, data_name) + if data_dict is not None: + new_data[data_name] = { + k: v[valid_masks[k]] for k, v in data_dict.items() + } + dataset = replace(dataset, **new_data) + elif policy == 'mean': + new_values = np.nanmean(dataset.X_num['train'], axis=0) + X_num = deepcopy(dataset.X_num) + for k, v in X_num.items(): + num_nan_indices = np.where(nan_masks[k]) + v[num_nan_indices] = np.take(new_values, num_nan_indices[1]) + dataset = replace(dataset, X_num=X_num) + else: + assert util.raise_unknown('policy', policy) + return dataset + + +# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 +def normalize( + X: ArrayDict, normalization: Normalization, seed: Optional[int], return_normalizer : bool = False +) -> ArrayDict: + X_train = X['train'] + if normalization == 'standard': + normalizer = sklearn.preprocessing.StandardScaler() + elif normalization == 'minmax': + normalizer = sklearn.preprocessing.MinMaxScaler() + elif normalization == 'quantile': + normalizer = sklearn.preprocessing.QuantileTransformer( + output_distribution='normal', + n_quantiles=max(min(X['train'].shape[0] // 30, 1000), 10), + subsample=1e9, + random_state=seed, + ) + # noise = 1e-3 + # if noise > 0: + # assert seed is not None + # stds = np.std(X_train, axis=0, keepdims=True) + # noise_std = noise / np.maximum(stds, noise) # type: ignore[code] + # X_train = X_train + noise_std * np.random.default_rng(seed).standard_normal( + # X_train.shape + # ) + else: + util.raise_unknown('normalization', normalization) + normalizer.fit(X_train) + if return_normalizer: + return {k: normalizer.transform(v) for k, v in X.items()}, normalizer + return {k: normalizer.transform(v) for k, v in X.items()} + + +def cat_process_nans(X: ArrayDict, policy: Optional[CatNanPolicy]) -> ArrayDict: + assert X is not None + nan_masks = {k: v == CAT_MISSING_VALUE for k, v in 
X.items()} + if any(x.any() for x in nan_masks.values()): # type: ignore[code] + if policy is None: + X_new = X + elif policy == 'most_frequent': + imputer = SimpleImputer(missing_values=CAT_MISSING_VALUE, strategy=policy) # type: ignore[code] + imputer.fit(X['train']) + X_new = {k: cast(np.ndarray, imputer.transform(v)) for k, v in X.items()} + else: + util.raise_unknown('categorical NaN policy', policy) + else: + assert policy is None + X_new = X + return X_new + + +def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: + assert 0.0 < min_frequency < 1.0 + min_count = round(len(X['train']) * min_frequency) + X_new = {x: [] for x in X} + for column_idx in range(X['train'].shape[1]): + counter = Counter(X['train'][:, column_idx].tolist()) + popular_categories = {k for k, v in counter.items() if v >= min_count} + for part in X_new: + X_new[part].append( + [ + (x if x in popular_categories else CAT_RARE_VALUE) + for x in X[part][:, column_idx].tolist() + ] + ) + return {k: np.array(v).T for k, v in X_new.items()} + + +def cat_encode( + X: ArrayDict, + encoding: Optional[CatEncoding], + y_train: Optional[np.ndarray], + seed: Optional[int], + return_encoder : bool = False +) -> Tuple[ArrayDict, bool, Optional[Any]]: # (X, is_converted_to_numerical) + if encoding != 'counter': + y_train = None + + # Step 1. 
Map strings to 0-based ranges + + if encoding is None: + unknown_value = np.iinfo('int64').max - 3 + oe = sklearn.preprocessing.OrdinalEncoder( + handle_unknown='use_encoded_value', # type: ignore[code] + unknown_value=unknown_value, # type: ignore[code] + dtype='int64', # type: ignore[code] + ).fit(X['train']) + encoder = make_pipeline(oe) + encoder.fit(X['train']) + X = {k: encoder.transform(v) for k, v in X.items()} + max_values = X['train'].max(axis=0) + for part in X.keys(): + if part == 'train': continue + for column_idx in range(X[part].shape[1]): + X[part][X[part][:, column_idx] == unknown_value, column_idx] = ( + max_values[column_idx] + 1 + ) + if return_encoder: + return (X, False, encoder) + return (X, False) + + # Step 2. Encode. + + elif encoding == 'one-hot': + ohe = sklearn.preprocessing.OneHotEncoder( + handle_unknown='ignore', sparse=False, dtype=np.float32 # type: ignore[code] + ) + encoder = make_pipeline(ohe) + + # encoder.steps.append(('ohe', ohe)) + encoder.fit(X['train']) + X = {k: encoder.transform(v) for k, v in X.items()} + elif encoding == 'counter': + assert y_train is not None + assert seed is not None + loe = LeaveOneOutEncoder(sigma=0.1, random_state=seed, return_df=False) + encoder.steps.append(('loe', loe)) + encoder.fit(X['train'], y_train) + X = {k: encoder.transform(v).astype('float32') for k, v in X.items()} # type: ignore[code] + if not isinstance(X['train'], pd.DataFrame): + X = {k: v.values for k, v in X.items()} # type: ignore[code] + else: + util.raise_unknown('encoding', encoding) + + if return_encoder: + return X, True, encoder # type: ignore[code] + return (X, True) + + +def build_target( + y: ArrayDict, policy: Optional[YPolicy], task_type: TaskType +) -> Tuple[ArrayDict, Dict[str, Any]]: + info: Dict[str, Any] = {'policy': policy} + if policy is None: + pass + elif policy == 'default': + if task_type == TaskType.REGRESSION: + mean, std = float(y['train'].mean()), float(y['train'].std()) + y = {k: (v - mean) / std for 
k, v in y.items()} + info['mean'] = mean + info['std'] = std + else: + util.raise_unknown('policy', policy) + return y, info + + +@dataclass(frozen=True) +class Transformations: + seed: int = 0 + normalization: Optional[Normalization] = None + num_nan_policy: Optional[NumNanPolicy] = None + cat_nan_policy: Optional[CatNanPolicy] = None + cat_min_frequency: Optional[float] = None + cat_encoding: Optional[CatEncoding] = None + y_policy: Optional[YPolicy] = 'default' + + +def transform_dataset( + dataset: Dataset, + transformations: Transformations, + cache_dir: Optional[Path], + return_transforms: bool = False +) -> Dataset: + # WARNING: the order of transformations matters. Moreover, the current + # implementation is not ideal in that sense. + if cache_dir is not None: + transformations_md5 = hashlib.md5( + str(transformations).encode('utf-8') + ).hexdigest() + transformations_str = '__'.join(map(str, astuple(transformations))) + cache_path = ( + cache_dir / f'cache__{transformations_str}__{transformations_md5}.pickle' + ) + if cache_path.exists(): + cache_transformations, value = util.load_pickle(cache_path) + if transformations == cache_transformations: + print( + f"Using cached features: {cache_dir.name + '/' + cache_path.name}" + ) + return value + else: + raise RuntimeError(f'Hash collision for {cache_path}') + else: + cache_path = None + + if dataset.X_num is not None: + dataset = num_process_nans(dataset, transformations.num_nan_policy) + + num_transform = None + cat_transform = None + X_num = dataset.X_num + + if X_num is not None and transformations.normalization is not None: + X_num, num_transform = normalize( + X_num, + transformations.normalization, + transformations.seed, + return_normalizer=True + ) + num_transform = num_transform + + if dataset.X_cat is None: + assert transformations.cat_nan_policy is None + assert transformations.cat_min_frequency is None + # assert transformations.cat_encoding is None + X_cat = None + else: + X_cat = 
cat_process_nans(dataset.X_cat, transformations.cat_nan_policy) + if transformations.cat_min_frequency is not None: + X_cat = cat_drop_rare(X_cat, transformations.cat_min_frequency) + X_cat, is_num, cat_transform = cat_encode( + X_cat, + transformations.cat_encoding, + dataset.y['train'], + transformations.seed, + return_encoder=True + ) + if is_num: + X_num = ( + X_cat + if X_num is None + else {x: np.hstack([X_num[x], X_cat[x]]) for x in X_num} + ) + X_cat = None + + y, y_info = build_target(dataset.y, transformations.y_policy, dataset.task_type) + + dataset = replace(dataset, X_num=X_num, X_cat=X_cat, y=y, y_info=y_info) + dataset.num_transform = num_transform + dataset.cat_transform = cat_transform + + if cache_path is not None: + util.dump_pickle((transformations, dataset), cache_path) + # if return_transforms: + # return dataset, num_transform, cat_transform + return dataset + + +def build_dataset( + path: Union[str, Path], + transformations: Transformations, + cache: bool +) -> Dataset: + path = Path(path) + dataset = Dataset.from_dir(path) + return transform_dataset(dataset, transformations, path if cache else None) + + +def prepare_tensors( + dataset: Dataset, device: Union[str, torch.device] +) -> Tuple[Optional[TensorDict], Optional[TensorDict], TensorDict]: + X_num, X_cat, Y = ( + None if x is None else {k: torch.as_tensor(v) for k, v in x.items()} + for x in [dataset.X_num, dataset.X_cat, dataset.y] + ) + if device.type != 'cpu': + X_num, X_cat, Y = ( + None if x is None else {k: v.to(device) for k, v in x.items()} + for x in [X_num, X_cat, Y] + ) + assert X_num is not None + assert Y is not None + if not dataset.is_multiclass: + Y = {k: v.float() for k, v in Y.items()} + return X_num, X_cat, Y + +############### +## DataLoader## +############### + +class TabDataset(torch.utils.data.Dataset): + def __init__( + self, dataset : Dataset, split : Literal['train', 'val', 'test'] + ): + super().__init__() + + self.X_num = 
torch.from_numpy(dataset.X_num[split]) if dataset.X_num is not None else None + self.X_cat = torch.from_numpy(dataset.X_cat[split]) if dataset.X_cat is not None else None + self.y = torch.from_numpy(dataset.y[split]) + + assert self.y is not None + assert self.X_num is not None or self.X_cat is not None + + def __len__(self): + return len(self.y) + + def __getitem__(self, idx): + out_dict = { + 'y': self.y[idx].long() if self.y is not None else None, + } + + x = np.empty((0,)) + if self.X_num is not None: + x = self.X_num[idx] + if self.X_cat is not None: + x = torch.cat([x, self.X_cat[idx]], dim=0) + return x.float(), out_dict + +def prepare_dataloader( + dataset : Dataset, + split : str, + batch_size: int, +): + + torch_dataset = TabDataset(dataset, split) + loader = torch.utils.data.DataLoader( + torch_dataset, + batch_size=batch_size, + shuffle=(split == 'train'), + num_workers=1, + ) + while True: + yield from loader + +def prepare_torch_dataloader( + dataset : Dataset, + split : str, + shuffle : bool, + batch_size: int, +) -> torch.utils.data.DataLoader: + + torch_dataset = TabDataset(dataset, split) + loader = torch.utils.data.DataLoader(torch_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1) + + return loader + +def dataset_from_csv(paths : Dict[str, str], cat_features, target, T): + assert 'train' in paths + y = {} + X_num = {} + X_cat = {} if len(cat_features) else None + for split in paths.keys(): + df = pd.read_csv(paths[split]) + y[split] = df[target].to_numpy().astype(float) + if X_cat is not None: + X_cat[split] = df[cat_features].to_numpy().astype(str) + X_num[split] = df.drop(cat_features + [target], axis=1).to_numpy().astype(float) + + dataset = Dataset(X_num, X_cat, y, {}, None, len(np.unique(y['train']))) + return transform_dataset(dataset, T, None) + +class FastTensorDataLoader: + """ + A DataLoader-like object for a set of tensors that can be much faster than + TensorDataset + DataLoader because dataloader grabs individual 
indices of + the dataset and calls cat (slow). + Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 + """ + def __init__(self, *tensors, batch_size=32, shuffle=False): + """ + Initialize a FastTensorDataLoader. + :param *tensors: tensors to store. Must have the same length @ dim 0. + :param batch_size: batch size to load. + :param shuffle: if True, shuffle the data *in-place* whenever an + iterator is created out of this object. + :returns: A FastTensorDataLoader. + """ + assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + self.tensors = tensors + + self.dataset_len = self.tensors[0].shape[0] + self.batch_size = batch_size + self.shuffle = shuffle + + # Calculate # batches + n_batches, remainder = divmod(self.dataset_len, self.batch_size) + if remainder > 0: + n_batches += 1 + self.n_batches = n_batches + def __iter__(self): + if self.shuffle: + r = torch.randperm(self.dataset_len) + self.tensors = [t[r] for t in self.tensors] + self.i = 0 + return self + + def __next__(self): + if self.i >= self.dataset_len: + raise StopIteration + batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) + self.i += self.batch_size + return batch + + def __len__(self): + return self.n_batches + +def prepare_fast_dataloader( + D : Dataset, + split : str, + batch_size: int +): + if D.X_cat is not None: + if D.X_num is not None: + X = torch.from_numpy(np.concatenate([D.X_num[split], D.X_cat[split]], axis=1)).float() + else: + X = torch.from_numpy(D.X_cat[split]).float() + else: + X = torch.from_numpy(D.X_num[split]).float() + y = torch.from_numpy(D.y[split]) + dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split=='train')) + while True: + yield from dataloader + +def prepare_fast_torch_dataloader( + D : Dataset, + split : str, + batch_size: int +): + if D.X_cat is not None: + X = torch.from_numpy(np.concatenate([D.X_num[split], D.X_cat[split]], axis=1)).float() + else: + X = 
torch.from_numpy(D.X_num[split]).float() + y = torch.from_numpy(D.y[split]) + dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split=='train')) + return dataloader + +def round_columns(X_real, X_synth, columns): + for col in columns: + uniq = np.unique(X_real[:,col]) + dist = cdist(X_synth[:, col][:, np.newaxis].astype(float), uniq[:, np.newaxis].astype(float)) + X_synth[:, col] = uniq[dist.argmin(axis=1)] + return X_synth + +def concat_features(D : Dataset): + if D.X_num is None: + assert D.X_cat is not None + X = {k: pd.DataFrame(v, columns=range(D.n_features)) for k, v in D.X_cat.items()} + elif D.X_cat is None: + assert D.X_num is not None + X = {k: pd.DataFrame(v, columns=range(D.n_features)) for k, v in D.X_num.items()} + else: + X = { + part: pd.concat( + [ + pd.DataFrame(D.X_num[part], columns=range(D.n_num_features)), + pd.DataFrame( + D.X_cat[part], + columns=range(D.n_num_features, D.n_features), + ), + ], + axis=1, + ) + for part in D.y.keys() + } + + return X + +def concat_to_pd(X_num, X_cat, y): + if X_num is None: + return pd.concat([ + pd.DataFrame(X_cat, columns=list(range(X_cat.shape[1]))), + pd.DataFrame(y, columns=['y']) + ], axis=1) + if X_cat is not None: + return pd.concat([ + pd.DataFrame(X_num, columns=list(range(X_num.shape[1]))), + pd.DataFrame(X_cat, columns=list(range(X_num.shape[1], X_num.shape[1] + X_cat.shape[1]))), + pd.DataFrame(y, columns=['y']) + ], axis=1) + return pd.concat([ + pd.DataFrame(X_num, columns=list(range(X_num.shape[1]))), + pd.DataFrame(y, columns=['y']) + ], axis=1) + +def read_pure_data(path, split='train'): + y = np.load(os.path.join(path, f'y_{split}.npy'), allow_pickle=True) + X_num = None + X_cat = None + if os.path.exists(os.path.join(path, f'X_num_{split}.npy')): + X_num = np.load(os.path.join(path, f'X_num_{split}.npy'), allow_pickle=True) + if os.path.exists(os.path.join(path, f'X_cat_{split}.npy')): + X_cat = np.load(os.path.join(path, f'X_cat_{split}.npy'), allow_pickle=True) + + 
return X_num, X_cat, y + +def read_changed_val(path, val_size=0.2): + path = Path(path) + X_num_train, X_cat_train, y_train = read_pure_data(path, 'train') + X_num_val, X_cat_val, y_val = read_pure_data(path, 'val') + is_regression = load_json(path / 'info.json')['task_type'] == 'regression' + + y = np.concatenate([y_train, y_val], axis=0) + + ixs = np.arange(y.shape[0]) + if is_regression: + train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777) + else: + train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777, stratify=y) + y_train = y[train_ixs] + y_val = y[val_ixs] + + if X_num_train is not None: + X_num = np.concatenate([X_num_train, X_num_val], axis=0) + X_num_train = X_num[train_ixs] + X_num_val = X_num[val_ixs] + + if X_cat_train is not None: + X_cat = np.concatenate([X_cat_train, X_cat_val], axis=0) + X_cat_train = X_cat[train_ixs] + X_cat_val = X_cat[val_ixs] + + return X_num_train, X_cat_train, y_train, X_num_val, X_cat_val, y_val + +############# + +def load_dataset_info(dataset_dir_name: str) -> Dict[str, Any]: + path = Path("data/" + dataset_dir_name) + info = util.load_json(path / 'info.json') + info['size'] = info['train_size'] + info['val_size'] + info['test_size'] + info['n_features'] = info['n_num_features'] + info['n_cat_features'] + info['path'] = path + return info diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py new file mode 100644 index 00000000..aeed3e2a --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py @@ -0,0 +1,168 @@ +import statistics +from dataclasses import dataclass +from typing import Any, Callable, Literal, cast + +import rtdl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import zero +from torch import Tensor + +from .util import TaskType + + +def cos_sin(x: Tensor) -> Tensor: + return torch.cat([torch.cos(x), 
torch.sin(x)], -1) + + +@dataclass +class PeriodicOptions: + n: int # the output size is 2 * n + sigma: float + trainable: bool + initialization: Literal['log-linear', 'normal'] + + +class Periodic(nn.Module): + def __init__(self, n_features: int, options: PeriodicOptions) -> None: + super().__init__() + if options.initialization == 'log-linear': + coefficients = options.sigma ** (torch.arange(options.n) / options.n) + coefficients = coefficients[None].repeat(n_features, 1) + else: + assert options.initialization == 'normal' + coefficients = torch.normal(0.0, options.sigma, (n_features, options.n)) + if options.trainable: + self.coefficients = nn.Parameter(coefficients) # type: ignore[code] + else: + self.register_buffer('coefficients', coefficients) + + def forward(self, x: Tensor) -> Tensor: + assert x.ndim == 2 + return cos_sin(2 * torch.pi * self.coefficients[None] * x[..., None]) + + +def get_n_parameters(m: nn.Module): + return sum(x.numel() for x in m.parameters() if x.requires_grad) + + +def get_loss_fn(task_type: TaskType) -> Callable[..., Tensor]: + return ( + F.binary_cross_entropy_with_logits + if task_type == TaskType.BINCLASS + else F.cross_entropy + if task_type == TaskType.MULTICLASS + else F.mse_loss + ) + + +def default_zero_weight_decay_condition(module_name, module, parameter_name, parameter): + del module_name, parameter + return parameter_name.endswith('bias') or isinstance( + module, + ( + nn.BatchNorm1d, + nn.LayerNorm, + nn.InstanceNorm1d, + rtdl.CLSToken, + rtdl.NumericalFeatureTokenizer, + rtdl.CategoricalFeatureTokenizer, + Periodic, + ), + ) + + +def split_parameters_by_weight_decay( + model: nn.Module, zero_weight_decay_condition=default_zero_weight_decay_condition +) -> list[dict[str, Any]]: + parameters_info = {} + for module_name, module in model.named_modules(): + for parameter_name, parameter in module.named_parameters(): + full_parameter_name = ( + f'{module_name}.{parameter_name}' if module_name else parameter_name + ) + 
parameters_info.setdefault(full_parameter_name, ([], parameter))[0].append( + zero_weight_decay_condition( + module_name, module, parameter_name, parameter + ) + ) + params_with_wd = {'params': []} + params_without_wd = {'params': [], 'weight_decay': 0.0} + for full_parameter_name, (results, parameter) in parameters_info.items(): + (params_without_wd if any(results) else params_with_wd)['params'].append( + parameter + ) + return [params_with_wd, params_without_wd] + + +def make_optimizer( + config: dict[str, Any], + parameter_groups, +) -> optim.Optimizer: + if config['optimizer'] == 'FT-Transformer-default': + return optim.AdamW(parameter_groups, lr=1e-4, weight_decay=1e-5) + return getattr(optim, config['optimizer'])( + parameter_groups, + **{x: config[x] for x in ['lr', 'weight_decay', 'momentum'] if x in config}, + ) + + +def get_lr(optimizer: optim.Optimizer) -> float: + return next(iter(optimizer.param_groups))['lr'] + + +def is_oom_exception(err: RuntimeError) -> bool: + return any( + x in str(err) + for x in [ + 'CUDA out of memory', + 'CUBLAS_STATUS_ALLOC_FAILED', + 'CUDA error: out of memory', + ] + ) + + +def train_with_auto_virtual_batch( + optimizer, + loss_fn, + step, + batch, + chunk_size: int, +) -> tuple[Tensor, int]: + batch_size = len(batch) + random_state = zero.random.get_state() + loss = None + while chunk_size != 0: + try: + zero.random.set_state(random_state) + optimizer.zero_grad() + if batch_size <= chunk_size: + loss = loss_fn(*step(batch)) + loss.backward() + else: + loss = None + for chunk in zero.iter_batches(batch, chunk_size): + chunk_loss = loss_fn(*step(chunk)) + chunk_loss = chunk_loss * (len(chunk) / batch_size) + chunk_loss.backward() + if loss is None: + loss = chunk_loss.detach() + else: + loss += chunk_loss.detach() + except RuntimeError as err: + if not is_oom_exception(err): + raise + chunk_size //= 2 + else: + break + if not chunk_size: + raise RuntimeError('Not enough memory even for batch_size=1') + optimizer.step() + 
return cast(Tensor, loss), chunk_size + + +def process_epoch_losses(losses: list[Tensor]) -> tuple[list[float], float]: + losses_ = torch.stack(losses).tolist() + return losses_, statistics.mean(losses_) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py new file mode 100644 index 00000000..64be89d7 --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py @@ -0,0 +1,39 @@ +""" +Have not used in TabDDPM project. +""" + +import datetime +import os +import shutil +import typing as ty +from pathlib import Path + +PROJ = Path('tab-ddpm/').absolute().resolve() +EXP = PROJ / 'exp' +DATA = PROJ / 'data' + + +def get_path(path: ty.Union[str, Path]) -> Path: + if isinstance(path, str): + path = Path(path) + if not path.is_absolute(): + path = PROJ / path + return path.resolve() + + +def get_relative_path(path: ty.Union[str, Path]) -> Path: + return get_path(path).relative_to(PROJ) + + +def duplicate_path( + src: ty.Union[str, Path], alternative_project_dir: ty.Union[str, Path] +) -> None: + src = get_path(src) + alternative_project_dir = get_path(alternative_project_dir) + dst = alternative_project_dir / src.relative_to(PROJ) + dst.parent.mkdir(parents=True, exist_ok=True) + if dst.exists(): + dst = dst.with_name( + dst.name + '_' + datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + ) + (shutil.copytree if src.is_dir() else shutil.copyfile)(src, dst) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py new file mode 100644 index 00000000..bdcac817 --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py @@ -0,0 +1,158 @@ +import enum +from typing import Any, Optional, Tuple, Dict, Union, cast +from functools import partial + +import numpy as np +import scipy.special +import sklearn.metrics as skm + +from . 
class MetricsReport:
    """Headline metrics picked out of a per-split evaluation report.

    ``report`` maps a split name to that split's metrics dictionary. For
    classification tasks the entries ``accuracy`` and
    ``macro avg -> f1-score`` are read (plus ``roc_auc`` for binary
    classification); for regression the entries ``r2`` and ``rmse`` are read.
    """

    def __init__(self, report: dict, task_type: TaskType):
        """Extract the metrics relevant for ``task_type`` from ``report``.

        Raises:
            ValueError: if ``task_type`` is not a supported task type.
            KeyError: if an expected metric is missing from ``report``.
        """
        self._res = {k: {} for k in report.keys()}
        if task_type in (TaskType.BINCLASS, TaskType.MULTICLASS):
            self._metrics_names = ["acc", "f1"]
            if task_type == TaskType.BINCLASS:
                # Register the name once; appending it while looping over the
                # splits would list "roc_auc" once per split.
                self._metrics_names.append("roc_auc")
            for k in report.keys():
                self._res[k]["acc"] = report[k]["accuracy"]
                self._res[k]["f1"] = report[k]["macro avg"]["f1-score"]
                if task_type == TaskType.BINCLASS:
                    self._res[k]["roc_auc"] = report[k]["roc_auc"]

        elif task_type == TaskType.REGRESSION:
            self._metrics_names = ["r2", "rmse"]
            for k in report.keys():
                self._res[k]["r2"] = report[k]["r2"]
                self._res[k]["rmse"] = report[k]["rmse"]
        else:
            # A plain string cannot be raised in Python 3; use a real exception.
            raise ValueError(f"Unknown TaskType: {task_type!r}")

    def get_splits_names(self) -> list[str]:
        """Return the split names found in the report."""
        return list(self._res.keys())

    def get_metrics_names(self) -> list[str]:
        """Return the names of the metrics stored for every split."""
        return self._metrics_names

    def get_metric(self, split: str, metric: str) -> float:
        """Return the value of ``metric`` on ``split``."""
        return self._res[split][metric]

    def get_val_score(self) -> float:
        """Return R2 on the "val" split if present, otherwise its F1."""
        return self._res["val"]["r2"] if "r2" in self._res["val"] else self._res["val"]["f1"]

    def get_test_score(self) -> float:
        """Return R2 on the "test" split if present, otherwise its F1."""
        return self._res["test"]["r2"] if "r2" in self._res["test"] else self._res["test"]["f1"]

    def print_metrics(self) -> dict:
        """Print the "val" and "test" metrics rounded to 4 decimals and return them."""
        res = {
            "val": {k: np.around(self._res["val"][k], 4) for k in self._res["val"]},
            "test": {k: np.around(self._res["test"][k], 4) for k in self._res["test"]}
        }

        print("*"*100)
        print("[val]")
        print(res["val"])
        print("[test]")
        print(res["test"])

        return res
def calculate_rmse(
    y_true: np.ndarray, y_pred: np.ndarray, std: Optional[float]
) -> float:
    """Return the RMSE of ``y_pred``, multiplied by ``std`` when it is given."""
    mse = skm.mean_squared_error(y_true, y_pred)
    rmse = mse ** 0.5
    if std is None:
        return rmse
    return rmse * std


def _get_labels_and_probs(
    y_pred: np.ndarray, task_type: TaskType, prediction_type: Optional[PredictionType]
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Turn classifier output into ``(labels, probabilities)``.

    With ``prediction_type=None`` the predictions are returned untouched as
    labels and no probabilities are produced. Logits go through a sigmoid
    (binary) or a softmax over axis 1 (multiclass); probabilities are used
    as they are. Labels are rounded probabilities (binary) or the argmax
    over axis 1 (multiclass), cast to int64.
    """
    assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS)

    if prediction_type is None:
        return y_pred, None

    is_binary = task_type == TaskType.BINCLASS
    if prediction_type == PredictionType.PROBS:
        probs = y_pred
    elif prediction_type == PredictionType.LOGITS:
        if is_binary:
            probs = scipy.special.expit(y_pred)
        else:
            probs = scipy.special.softmax(y_pred, axis=1)
    else:
        util.raise_unknown('prediction_type', prediction_type)

    assert probs is not None
    if is_binary:
        labels = np.round(probs)
    else:
        labels = probs.argmax(axis=1)
    return labels.astype('int64'), probs


def calculate_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    task_type: Union[str, TaskType],
    prediction_type: Optional[Union[str, PredictionType]],
    y_info: Dict[str, Any],
) -> Dict[str, Any]:
    """Compute the evaluation metrics for one set of predictions.

    Example: ``calculate_metrics(y_true, y_pred, 'binclass', 'logits', {})``

    Regression requires ``prediction_type=None`` and ``'std'`` in ``y_info``
    and yields ``{'rmse', 'r2'}``. Classification yields scikit-learn's
    classification report as a dictionary, extended with ``'roc_auc'`` for
    binary tasks.
    """
    task_type = TaskType(task_type)
    if prediction_type is not None:
        prediction_type = PredictionType(prediction_type)

    if task_type == TaskType.REGRESSION:
        assert prediction_type is None
        assert 'std' in y_info
        rmse = calculate_rmse(y_true, y_pred, y_info['std'])
        r2 = skm.r2_score(y_true, y_pred)
        return {'rmse': rmse, 'r2': r2}

    labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type)
    report = cast(
        Dict[str, Any], skm.classification_report(y_true, labels, output_dict=True)
    )
    if task_type == TaskType.BINCLASS:
        report['roc_auc'] = skm.roc_auc_score(y_true, probs)
    return report
def update_training_log(training_log, data, metrics):
    """Merge one round of ``data`` and ``metrics`` into ``training_log`` in place.

    Nested dictionaries are merged recursively. A list value extends the
    list stored under the same key; any other value is appended to it.
    ``metrics`` is given as ``{part: {metric: value}}`` and is stored
    transposed, i.e. as ``{metric: {part: [values...]}}``.
    """

    def _merge(target, incoming):
        for key, value in incoming.items():
            if isinstance(value, dict):
                _merge(target.setdefault(key, {}), value)
                continue
            bucket = target.setdefault(key, [])
            if isinstance(value, list):
                bucket.extend(value)
            else:
                bucket.append(value)

    _merge(training_log, data)

    by_metric = {}
    for part, part_metrics in metrics.items():
        for metric_name, value in part_metrics.items():
            if metric_name not in by_metric:
                by_metric[metric_name] = {}
            by_metric[metric_name][part] = value
    _merge(training_log, by_metric)


def raise_unknown(unknown_what: str, unknown_value: Any):
    """Raise ``ValueError`` reporting an unknown value of the given kind."""
    message = f'Unknown {unknown_what}: {unknown_value}'
    raise ValueError(message)


def _replace(data, condition, value):
    """Return a copy of ``data`` with matching leaves replaced by ``value``.

    Dictionaries and lists are traversed recursively (and rebuilt); every
    other object is a leaf and is replaced when ``condition(leaf)`` is true.
    """
    if isinstance(data, dict):
        return {key: _replace(item, condition, value) for key, item in data.items()}
    if isinstance(data, list):
        return [_replace(item, condition, value) for item in data]
    if condition(data):
        return value
    return data
def load_json(path: Union[Path, str], **kwargs) -> Any:
    """Read ``path`` as text and parse it as JSON; ``kwargs`` go to ``json.loads``."""
    text = Path(path).read_text()
    return json.loads(text, **kwargs)


def dump_json(x: Any, path: Union[Path, str], **kwargs) -> None:
    """Write ``x`` to ``path`` as JSON followed by a newline.

    ``kwargs`` go to ``json.dumps``; ``indent`` defaults to 4 when omitted.
    """
    if 'indent' not in kwargs:
        kwargs['indent'] = 4
    serialized = json.dumps(x, **kwargs)
    Path(path).write_text(serialized + '\n')


def load_pickle(path: Union[Path, str], **kwargs) -> Any:
    """Unpickle the content of ``path``; ``kwargs`` go to ``pickle.loads``."""
    # NOTE: unpickling executes arbitrary code; only use on trusted files.
    payload = Path(path).read_bytes()
    return pickle.loads(payload, **kwargs)


def dump_pickle(x: Any, path: Union[Path, str], **kwargs) -> None:
    """Pickle ``x`` into ``path``; ``kwargs`` go to ``pickle.dumps``."""
    payload = pickle.dumps(x, **kwargs)
    Path(path).write_bytes(payload)


def load(path: Union[Path, str], **kwargs) -> Any:
    """Load ``path`` with the module-level ``load_<suffix>`` function.

    The function is looked up by the file suffix without the leading dot
    (e.g. ``.json`` -> ``load_json``); a suffix without a matching function
    raises ``KeyError``.
    """
    target = Path(path)
    loader = globals()[f'load_{target.suffix[1:]}']
    return loader(target, **kwargs)


def dump(x: Any, path: Union[Path, str], **kwargs) -> Any:
    """Save ``x`` to ``path`` with the module-level ``dump_<suffix>`` function.

    The function is looked up by the file suffix without the leading dot
    (e.g. ``.pickle`` -> ``dump_pickle``); a suffix without a matching
    function raises ``KeyError``.
    """
    target = Path(path)
    dumper = globals()[f'dump_{target.suffix[1:]}']
    return dumper(x, target, **kwargs)
def start(
    config_cls: Type[T] = RawConfig,
    argv: Optional[List[str]] = None,
    patch_raw_config: Optional[Callable[[RawConfig], None]] = None,
) -> Tuple[T, Path, Report]:  # config  # output dir  # report
    """Parse the command line, load the config and prepare the output directory.

    Args:
        config_cls: a dataclass type to build the config with, or ``dict`` /
            ``RawConfig`` to keep the raw dictionary.
        argv: explicit arguments; the first item must be the path of the
            script/notebook. When omitted, ``sys.argv`` is parsed and the
            program is taken from ``__main__.__file__``.
        patch_raw_config: optional callback that modifies the loaded raw
            config in place before it is converted.

    Returns:
        ``(config, output_dir, report)``. ``output_dir`` is the config path
        without its suffix; ``report`` is also written to it.

    The process exits (``sys.exit``) when the output directory already
    exists and neither ``--force`` nor ``--continue`` is given, or when
    ``--continue`` is given but the output is already marked as ``DONE``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', metavar='FILE')
    parser.add_argument('--force', action='store_true')
    parser.add_argument('--continue', action='store_true', dest='continue_')
    if argv is None:
        program = __main__.__file__
        args = parser.parse_args()
    else:
        program = argv[0]
        try:
            args = parser.parse_args(argv[1:])
        except Exception:
            print(
                'Failed to parse `argv`.'
                ' Remember that the first item of `argv` must be the path (relative to'
                ' the project root) to the script/notebook.'
            )
            raise
    # NOTE: `args` must not be parsed again from the full `argv` here: that
    # would consume the program path (argv[0]) as the `config` argument.

    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists():
        assert args.continue_

    config_path = env.get_path(args.config)
    output_dir = config_path.with_suffix('')
    _print_sep('=')
    print(f'[output] {output_dir}')
    _print_sep('=')

    assert config_path.exists()
    raw_config = load_config(config_path)
    if patch_raw_config is not None:
        patch_raw_config(raw_config)
    if is_dataclass(config_cls):
        config = from_dict(config_cls, raw_config)
        full_raw_config = asdict(config)
    else:
        # The default `RawConfig` is `Dict[str, Any]`, which is not the
        # `dict` class itself, so its origin has to be accepted as well.
        assert config_cls is dict or get_origin(config_cls) is dict
        # `asdict` only accepts dataclass instances; a raw dictionary is
        # already its own "full" representation.
        full_raw_config = config = raw_config

    if output_dir.exists():
        if args.force:
            print('Removing the existing output and creating a new one...')
            shutil.rmtree(output_dir)
            output_dir.mkdir()
        elif not args.continue_:
            backup_output(output_dir)
            print('The output directory already exists. Done!\n')
            sys.exit()
        elif output_dir.joinpath('DONE').exists():
            backup_output(output_dir)
            print('The "DONE" file already exists. Done!')
            sys.exit()
        else:
            print('Continuing with the existing output...')
    else:
        print('Creating the output...')
        output_dir.mkdir()

    report = {
        'program': str(env.get_relative_path(program)),
        'environment': {},
        'config': full_raw_config,
    }
    if torch.cuda.is_available():  # type: ignore[code]
        report['environment'].update(
            {
                'CUDA_VISIBLE_DEVICES': os.environ.get('CUDA_VISIBLE_DEVICES'),
                'gpus': zero.hardware.get_gpus_info(),
                'torch.version.cuda': torch.version.cuda,
                'torch.backends.cudnn.version()': torch.backends.cudnn.version(),  # type: ignore[code]
                'torch.cuda.nccl.version()': torch.cuda.nccl.version(),  # type: ignore[code]
            }
        )
    dump_report(report, output_dir)
    dump_json(raw_config, output_dir / 'raw_config.json')
    _print_sep('-')
    pprint(full_raw_config, width=100)
    _print_sep('-')
    return cast(config_cls, config), output_dir, report
def _get_scores(metrics: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, float]]:
    """Return ``{part: score}`` for ``metrics``, or ``None`` if there are no scores.

    Whether scores are available is decided by the first part only: if it
    has a ``'score'`` entry, every part is expected to have one. An empty
    ``metrics`` dictionary has no scores and yields ``None`` (instead of
    failing with ``StopIteration``).
    """
    if not metrics:
        return None
    first_part_metrics = next(iter(metrics.values()))
    if 'score' not in first_part_metrics:
        return None
    return {part: part_metrics['score'] for part, part_metrics in metrics.items()}


def format_scores(metrics: Dict[str, Dict[str, Any]]) -> str:
    """Format the scores of the available parts as ``'[part] 0.123 ...'``.

    Parts are listed in the fixed order test, val, train; parts missing from
    ``metrics`` are skipped. Every listed part must have a ``'score'`` entry.
    """
    return ' '.join(
        f"[{x}] {metrics[x]['score']:.3f}"
        for x in ['test', 'val', 'train']
        if x in metrics
    )
def from_dict(datacls: Type[T], data: dict) -> T:
    """Build an instance of the dataclass ``datacls`` from a plain dictionary.

    ``data`` is deep-copied first, so the input is not modified. Values of
    fields whose declared type is a dataclass, or ``Optional[<dataclass>]``
    with a non-``None`` value, are converted recursively; all other values
    are passed to the constructor unchanged.
    """
    assert is_dataclass(datacls)
    values = deepcopy(data)
    for spec in fields(datacls):
        name = spec.name
        declared = spec.type
        if name not in values:
            continue
        if is_dataclass(declared):
            values[name] = from_dict(declared, values[name])
            continue
        if get_origin(declared) is not Union:
            continue
        type_args = get_args(declared)
        # Optional[X] is Union[X, None]; only a dataclass X is converted.
        is_optional_dataclass = (
            len(type_args) == 2
            and type_args[1] is type(None)
            and is_dataclass(type_args[0])
        )
        if is_optional_dataclass and values[name] is not None:
            values[name] = from_dict(type_args[0], values[name])
    return datacls(**values)


def replace_factor_with_value(
    config: RawConfig,
    key: str,
    reference_value: int,
    bounds: Tuple[float, float],
) -> None:
    """Replace ``config['<key>_factor']`` with an absolute ``config[key]`` in place.

    Exactly one of ``key`` and ``<key>_factor`` must be present. When the
    factor is given, it must lie within ``bounds`` (inclusive); it is removed
    from ``config`` and ``config[key]`` is set to
    ``int(factor * reference_value)``.
    """
    factor_key = f'{key}_factor'
    if factor_key not in config:
        assert key in config
        return
    assert key not in config
    factor = config.pop(factor_key)
    assert bounds[0] <= factor <= bounds[1]
    config[key] = int(factor * reference_value)
def save_file(parent_dir, config_path):
    """Copy the file ``config_path`` to the destination file ``parent_dir``.

    Despite its name, ``parent_dir`` is the full path of the destination
    *file*; its parent directory is created when missing. Copying a file
    onto itself is silently ignored.
    """
    dst = os.path.join(parent_dir)
    dst_dir = os.path.dirname(dst)
    # A bare file name has no directory part, and os.makedirs('') would fail.
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    try:
        shutil.copyfile(os.path.abspath(config_path), dst)
    except shutil.SameFileError:
        # Source and destination are the same file: nothing to copy.
        pass
def to_good_ohe(ohe, X):
    """Turn soft one-hot groups in ``X`` into hard 0/1 indicators.

    The columns of ``X`` are split into consecutive groups whose widths are
    given by the list ``ohe._n_features_outs``. Within each group, every
    entry equal to the row-wise maximum becomes 1 and all others become 0
    (ties therefore produce several ones). The groups are concatenated back
    in their original order.
    """
    boundaries = np.cumsum([0] + ohe._n_features_outs)
    hardened = []
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        group = X[:, start:stop]
        row_max = np.max(group, axis=1).reshape(-1, 1)
        hardened.append(np.where(group - row_max >= 0, 1, 0))
    return np.hstack(hardened)
T_dict['cat_encoding'] == 'one-hot': + K = np.array([0]) + + num_numerical_features_ = D.X_num['train'].shape[1] if D.X_num is not None else 0 + d_in = np.sum(K) + num_numerical_features_ + model_params['d_in'] = int(d_in) + model = get_model( + model_type, + model_params, + num_numerical_features_, + category_sizes=D.get_category_sizes('train') + ) + + model.load_state_dict( + torch.load(model_path, map_location="cpu") + ) + + diffusion = GaussianMultinomialDiffusion( + K, + num_numerical_features=num_numerical_features_, + denoise_fn=model, num_timesteps=num_timesteps, + gaussian_loss_type=gaussian_loss_type, scheduler=scheduler, device=device + ) + + diffusion.to(device) + diffusion.eval() + + _, empirical_class_dist = torch.unique(torch.from_numpy(D.y['train']), return_counts=True) + # empirical_class_dist = empirical_class_dist.float() + torch.tensor([-5000., 10000.]).float() + if disbalance == 'fix': + empirical_class_dist[0], empirical_class_dist[1] = empirical_class_dist[1], empirical_class_dist[0] + x_gen, y_gen = diffusion.sample_all(num_samples, batch_size, empirical_class_dist.float(), ddim=False) + + elif disbalance == 'fill': + ix_major = empirical_class_dist.argmax().item() + val_major = empirical_class_dist[ix_major].item() + x_gen, y_gen = [], [] + for i in range(empirical_class_dist.shape[0]): + if i == ix_major: + continue + distrib = torch.zeros_like(empirical_class_dist) + distrib[i] = 1 + num_samples = val_major - empirical_class_dist[i].item() + x_temp, y_temp = diffusion.sample_all(num_samples, batch_size, distrib.float(), ddim=False) + x_gen.append(x_temp) + y_gen.append(y_temp) + + x_gen = torch.cat(x_gen, dim=0) + y_gen = torch.cat(y_gen, dim=0) + + else: + x_gen, y_gen = diffusion.sample_all(num_samples, batch_size, empirical_class_dist.float(), ddim=False) + + + # try: + # except FoundNANsError as ex: + # print("Found NaNs during sampling!") + # loader = lib.prepare_fast_dataloader(D, 'train', 8) + # x_gen = next(loader)[0] + # y_gen = 
torch.multinomial( + # empirical_class_dist.float(), + # num_samples=8, + # replacement=True + # ) + X_gen, y_gen = x_gen.numpy(), y_gen.numpy() + + ### + # X_num_unnorm = X_gen[:, :num_numerical_features] + # lo = np.percentile(X_num_unnorm, 2.5, axis=0) + # hi = np.percentile(X_num_unnorm, 97.5, axis=0) + # idx = (lo < X_num_unnorm) & (hi > X_num_unnorm) + # X_gen = X_gen[np.all(idx, axis=1)] + # y_gen = y_gen[np.all(idx, axis=1)] + ### + + num_numerical_features = num_numerical_features + int(D.is_regression and not model_params["is_y_cond"]) + + X_num_ = X_gen + if num_numerical_features < X_gen.shape[1]: + np.save(os.path.join(parent_dir, 'X_cat_unnorm'), X_gen[:, num_numerical_features:]) + # _, _, cat_encoder = lib.cat_encode({'train': X_cat_real}, T_dict['cat_encoding'], y_real, T_dict['seed'], True) + if T_dict['cat_encoding'] == 'one-hot': + X_gen[:, num_numerical_features:] = to_good_ohe(D.cat_transform.steps[0][1], X_num_[:, num_numerical_features:]) + X_cat = D.cat_transform.inverse_transform(X_gen[:, num_numerical_features:]) + + if num_numerical_features_ != 0: + # _, normalize = lib.normalize({'train' : X_num_real}, T_dict['normalization'], T_dict['seed'], True) + np.save(os.path.join(parent_dir, 'X_num_unnorm'), X_gen[:, :num_numerical_features]) + X_num_ = D.num_transform.inverse_transform(X_gen[:, :num_numerical_features]) + X_num = X_num_[:, :num_numerical_features] + + X_num_real = np.load(os.path.join(real_data_path, "X_num_train.npy"), allow_pickle=True) + disc_cols = [] + for col in range(X_num_real.shape[1]): + uniq_vals = np.unique(X_num_real[:, col]) + if len(uniq_vals) <= 32 and ((uniq_vals - np.round(uniq_vals)) == 0).all(): + disc_cols.append(col) + print("Discrete cols:", disc_cols) + if model_params['num_classes'] == 0: + y_gen = X_num[:, 0] + X_num = X_num[:, 1:] + if len(disc_cols): + X_num = round_columns(X_num_real, X_num, disc_cols) + + if num_numerical_features != 0: + print("Num shape: ", X_num.shape) + 
class Trainer:
    """Training loop for a diffusion model with a linearly decayed learning rate.

    Besides optimizing ``diffusion`` with AdamW, the trainer keeps a deep
    copy of ``diffusion._denoise_fn`` in ``ema_model`` and records averaged
    losses in the ``loss_history`` data frame.
    """

    def __init__(self, diffusion, train_iter, lr, weight_decay, steps, device=torch.device('cuda:1')):
        """Store the training setup and create the optimizer.

        ``train_iter`` is consumed with ``next()`` once per step and must
        yield ``(x, y)`` pairs; ``steps`` is the total number of steps.
        """
        self.diffusion = diffusion
        # Separate copy of the denoiser, detached from the autograd graph;
        # it is handed to `update_ema` after every training step.
        self.ema_model = deepcopy(self.diffusion._denoise_fn)
        for param in self.ema_model.parameters():
            param.detach_()

        self.train_iter = train_iter
        self.steps = steps
        self.init_lr = lr
        self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay)
        self.device = device
        self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss'])
        # Losses are averaged and recorded every `log_every` steps and
        # printed when the step is also a multiple of `print_every`.
        self.log_every = 100
        self.print_every = 500
        # NOTE(review): `ema_every` is not read anywhere in this class.
        self.ema_every = 1000

    def _anneal_lr(self, step):
        """Set the learning rate to ``init_lr * (1 - step / steps)`` for all groups."""
        frac_done = step / self.steps
        lr = self.init_lr * (1 - frac_done)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _run_step(self, x, out_dict):
        """Run one optimization step and return ``(loss_multi, loss_gauss)``.

        ``out_dict`` is updated in place: its values are cast to long and
        moved to the training device.
        """
        x = x.to(self.device)
        for k in out_dict:
            out_dict[k] = out_dict[k].long().to(self.device)
        self.optimizer.zero_grad()
        loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict)
        # Both loss terms are optimized jointly as an unweighted sum.
        loss = loss_multi + loss_gauss
        loss.backward()
        self.optimizer.step()

        return loss_multi, loss_gauss

    def run_loop(self):
        """Train for ``self.steps`` steps, logging averaged losses along the way."""
        step = 0
        # Running sums of the per-sample losses since the last log entry.
        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0

        curr_count = 0
        while step < self.steps:
            x, out_dict = next(self.train_iter)
            out_dict = {'y': out_dict}
            batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict)

            # The new learning rate takes effect from the next step on.
            self._anneal_lr(step)

            # Weight the batch losses by batch size so that the logged value
            # is an average over samples rather than over batches.
            curr_count += len(x)
            curr_loss_multi += batch_loss_multi.item() * len(x)
            curr_loss_gauss += batch_loss_gauss.item() * len(x)

            if (step + 1) % self.log_every == 0:
                mloss = np.around(curr_loss_multi / curr_count, 4)
                gloss = np.around(curr_loss_gauss / curr_count, 4)
                if (step + 1) % self.print_every == 0:
                    print(f'Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}')
                self.loss_history.loc[len(self.loss_history)] =[step + 1, mloss, gloss, mloss + gloss]
                curr_count = 0
                curr_loss_gauss = 0.0
                curr_loss_multi = 0.0

            # `update_ema` is defined outside this class; presumably it moves
            # `ema_model` towards the current denoiser weights -- confirm.
            update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters())

            step += 1
num_numerical_features, + category_sizes=dataset.get_category_sizes('train') + ) + model.to(device) + + # train_loader = lib.prepare_beton_loader(dataset, split='train', batch_size=batch_size) + train_loader = lib.prepare_fast_dataloader(dataset, split='train', batch_size=batch_size) + + diffusion = GaussianMultinomialDiffusion( + num_classes=K, + num_numerical_features=num_numerical_features, + denoise_fn=model, + gaussian_loss_type=gaussian_loss_type, + num_timesteps=num_timesteps, + scheduler=scheduler, + device=device + ) + diffusion.to(device) + diffusion.train() + + trainer = Trainer( + diffusion, + train_loader, + lr=lr, + weight_decay=weight_decay, + steps=steps, + device=device + ) + trainer.run_loop() + + trainer.loss_history.to_csv(os.path.join(parent_dir, 'loss.csv'), index=False) + torch.save(diffusion._denoise_fn.state_dict(), os.path.join(parent_dir, 'model.pt')) + torch.save(trainer.ema_model.state_dict(), os.path.join(parent_dir, 'model_ema.pt')) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.tune.py b/src/synthcity/plugins/core/models/tabular_ddpm/.tune.py new file mode 100644 index 00000000..5a95dc23 --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/.tune.py @@ -0,0 +1,127 @@ +import subprocess +import lib +import os +import optuna +from copy import deepcopy +import shutil +import argparse +from pathlib import Path + +parser = argparse.ArgumentParser() +parser.add_argument('ds_name', type=str) +parser.add_argument('train_size', type=int) +parser.add_argument('eval_type', type=str) +parser.add_argument('eval_model', type=str) +parser.add_argument('prefix', type=str) +parser.add_argument('--eval_seeds', action='store_true', default=False) + +args = parser.parse_args() +train_size = args.train_size +ds_name = args.ds_name +eval_type = args.eval_type +assert eval_type in ('merged', 'synthetic') +prefix = str(args.prefix) + +pipeline = f'scripts/pipeline.py' +base_config_path = f'exp/{ds_name}/config.toml' +parent_path = 
def _suggest_mlp_layers(trial):
    """Sample the list of hidden layer widths of the MLP for one trial.

    The number of layers is twice the suggested ``n_layers`` (1 to 4). Each
    width is a power of two with a suggested exponent between 7 and 10: one
    exponent for the first layer, one shared by all middle layers and one
    for the last layer.
    """
    min_n_layers, max_n_layers, d_min, d_max = 1, 4, 7, 10

    def suggest_dim(name):
        return 2 ** trial.suggest_int(name, d_min, d_max)

    n_layers = 2 * trial.suggest_int('n_layers', min_n_layers, max_n_layers)
    d_layers = []
    if n_layers:
        d_layers.append(suggest_dim('d_first'))
    if n_layers > 2:
        # A single suggestion is shared by all middle layers.
        d_layers.extend([suggest_dim('d_middle')] * (n_layers - 2))
    if n_layers > 1:
        d_layers.append(suggest_dim('d_last'))
    return d_layers
base_config['parent_dir'] = str(exps_path / f"{trial.number}") + base_config['eval']['type']['eval_model'] = args.eval_model + if args.eval_model == "mlp": + base_config['eval']['T']['normalization'] = "quantile" + base_config['eval']['T']['cat_encoding'] = "one-hot" + + trial.set_user_attr("config", base_config) + + lib.dump_config(base_config, exps_path / 'config.toml') + + subprocess.run(['python3.9', f'{pipeline}', '--config', f'{exps_path / "config.toml"}', '--train', '--change_val'], check=True) + + n_datasets = 5 + score = 0.0 + + for sample_seed in range(n_datasets): + base_config['sample']['seed'] = sample_seed + lib.dump_config(base_config, exps_path / 'config.toml') + + subprocess.run(['python3.9', f'{pipeline}', '--config', f'{exps_path / "config.toml"}', '--sample', '--eval', '--change_val'], check=True) + + report_path = str(Path(base_config['parent_dir']) / f'results_{args.eval_model}.json') + report = lib.load_json(report_path) + + if 'r2' in report['metrics']['val']: + score += report['metrics']['val']['r2'] + else: + score += report['metrics']['val']['macro avg']['f1-score'] + + shutil.rmtree(exps_path / f"{trial.number}") + + return score / n_datasets + +study = optuna.create_study( + direction='maximize', + sampler=optuna.samplers.TPESampler(seed=0), +) + +study.optimize(objective, n_trials=50, show_progress_bar=True) + +best_config_path = parent_path / f'{prefix}_best/config.toml' +best_config = study.best_trial.user_attrs['config'] +best_config["parent_dir"] = str(parent_path / f'{prefix}_best/') + +os.makedirs(parent_path / f'{prefix}_best', exist_ok=True) +lib.dump_config(best_config, best_config_path) +lib.dump_json(optuna.importance.get_param_importances(study), parent_path / f'{prefix}_best/importance.json') + +subprocess.run(['python3.9', f'{pipeline}', '--config', f'{best_config_path}', '--train', '--sample'], check=True) + +if args.eval_seeds: + best_exp = str(parent_path / f'{prefix}_best/config.toml') + subprocess.run(['python3.9', 
def get_model(
    model_name,
    model_params,
    n_num_features,
    category_sizes
):
    """Instantiate the denoising network selected by ``model_name``.

    :param model_name: ``'mlp'`` or ``'resnet'``.
    :param model_params: keyword arguments forwarded to the model constructor.
    :param n_num_features: accepted for interface compatibility; not used here.
    :param category_sizes: accepted for interface compatibility; not used here.
    :return: the constructed model.
    :raises ValueError: if ``model_name`` is not a known model.
    """
    if model_name == 'mlp':
        return MLPDiffusion(**model_params)
    if model_name == 'resnet':
        return ResNetDiffusion(**model_params)
    # Raising a bare string is itself a TypeError in Python 3, which hid the
    # real problem; raise a proper exception that names the bad value.
    raise ValueError(f"Unknown model! Got {model_name!r}")
def concat_y_to_X(X, y):
    """Return ``y`` as a leading column in front of ``X``.

    When ``X`` is ``None`` the result is just ``y`` reshaped to one column.
    """
    target_col = y.reshape(-1, 1)
    if X is None:
        return target_col
    return np.concatenate([target_col, X], axis=1)


def make_dataset(
    data_path: str,
    T: lib.Transformations,
    num_classes: int,
    is_y_cond: bool,
    change_val: bool
):
    """Load the train/val/test splits under ``data_path`` and transform them.

    ``num_classes > 0`` selects the classification path, otherwise the
    regression path. When ``is_y_cond`` is false the target is prepended to
    the features: to the categorical block for classification, to the
    numerical block for regression.

    :param data_path: directory holding the split files and ``info.json``.
    :param T: transformations passed to ``lib.transform_dataset``.
    :param num_classes: positive for classification, otherwise regression.
    :param is_y_cond: if true, the target is kept out of the feature blocks.
    :param change_val: if true, the dataset is passed through ``lib.change_val``.
    :return: the result of ``lib.transform_dataset``.
    """
    is_classification = num_classes > 0
    cat_on_disk = os.path.exists(os.path.join(data_path, 'X_cat_train.npy'))
    num_on_disk = os.path.exists(os.path.join(data_path, 'X_num_train.npy'))

    # A feature block exists if it is on disk, or if the target is about to
    # be folded into it (which creates the block even when no file exists).
    keep_cat = cat_on_disk or (is_classification and not is_y_cond)
    keep_num = num_on_disk or (not is_classification and not is_y_cond)

    X_cat = {} if keep_cat else None
    X_num = {} if keep_num else None
    y = {}

    for split in ('train', 'val', 'test'):
        num_part, cat_part, target = lib.read_pure_data(data_path, split)
        if not is_y_cond:
            if is_classification:
                cat_part = concat_y_to_X(cat_part, target)
            else:
                num_part = concat_y_to_X(num_part, target)
        if X_num is not None:
            X_num[split] = num_part
        if X_cat is not None:
            X_cat[split] = cat_part
        y[split] = target

    info = lib.load_json(os.path.join(data_path, 'info.json'))

    dataset = lib.Dataset(
        X_num,
        X_cat,
        y,
        y_info={},
        task_type=lib.TaskType(info['task_type']),
        n_classes=info.get('n_classes')
    )

    if change_val:
        dataset = lib.change_val(dataset)

    return lib.transform_dataset(dataset, T, None)
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return the pre-defined beta schedule registered under ``schedule_name``.

    ``"linear"`` is the schedule from Ho et al, rescaled so that it works for
    any number of diffusion steps; ``"cosine"`` discretizes a squared-cosine
    alpha-bar curve via ``betas_for_alpha_bar``.

    :param schedule_name: ``"linear"`` or ``"cosine"``.
    :param num_diffusion_timesteps: number of betas to produce.
    :return: a 1-D numpy array of betas.
    :raises NotImplementedError: for any other schedule name.
    """
    if schedule_name == "linear":
        # Endpoints are stated for 1000 steps and scaled for other lengths.
        scale = 1000 / num_diffusion_timesteps
        return np.linspace(
            scale * 0.0001,
            scale * 0.02,
            num_diffusion_timesteps,
            dtype=np.float64,
        )

    if schedule_name == "cosine":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a cumulative-product curve ``alpha_bar(t)``, t in [0, 1],
    into per-step betas.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: callable mapping t in [0, 1] to the cumulative product
                      of (1 - beta) up to that point of the process.
    :param max_beta: upper clamp for every beta; values below 1 prevent
                     singularities.
    :return: a 1-D numpy array with ``num_diffusion_timesteps`` entries.
    """
    steps = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ])
+ ' This is expensive both in terms of memory and computation.') + + self.num_numerical_features = num_numerical_features + self.num_classes = num_classes # it as a vector [K1, K2, ..., Km] + self.num_classes_expanded = torch.from_numpy( + np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))]) + ).to(device) + + self.slices_for_classes = [np.arange(self.num_classes[0])] + offsets = np.cumsum(self.num_classes) + for i in range(1, len(offsets)): + self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i])) + self.offsets = torch.from_numpy(np.append([0], offsets)).to(device) + + self._denoise_fn = denoise_fn + self.gaussian_loss_type = gaussian_loss_type + self.gaussian_parametrization = gaussian_parametrization + self.multinomial_loss_type = multinomial_loss_type + self.num_timesteps = num_timesteps + self.parametrization = parametrization + self.scheduler = scheduler + + alphas = 1. - get_named_beta_schedule(scheduler, num_timesteps) + alphas = torch.tensor(alphas.astype('float64')) + betas = 1. 
- alphas + + log_alpha = np.log(alphas) + log_cumprod_alpha = np.cumsum(log_alpha) + + log_1_min_alpha = log_1_min_a(log_alpha) + log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) + + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = torch.tensor(np.append(1.0, alphas_cumprod[:-1])) + alphas_cumprod_next = torch.tensor(np.append(alphas_cumprod[1:], 0.0)) + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) + sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) + + # Gaussian diffusion + + self.posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + self.posterior_log_variance_clipped = torch.from_numpy( + np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + ).float().to(device) + self.posterior_mean_coef1 = ( + betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ).float().to(device) + self.posterior_mean_coef2 = ( + (1.0 - alphas_cumprod_prev) + * np.sqrt(alphas.numpy()) + / (1.0 - alphas_cumprod) + ).float().to(device) + + assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5 + assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 + assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5 + + # Convert to float32 and register buffers. 
+ self.register_buffer('alphas', alphas.float().to(device)) + self.register_buffer('log_alpha', log_alpha.float().to(device)) + self.register_buffer('log_1_min_alpha', log_1_min_alpha.float().to(device)) + self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float().to(device)) + self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float().to(device)) + self.register_buffer('alphas_cumprod', alphas_cumprod.float().to(device)) + self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float().to(device)) + self.register_buffer('alphas_cumprod_next', alphas_cumprod_next.float().to(device)) + self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod.float().to(device)) + self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod.float().to(device)) + self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod.float().to(device)) + self.register_buffer('sqrt_recipm1_alphas_cumprod', sqrt_recipm1_alphas_cumprod.float().to(device)) + + self.register_buffer('Lt_history', torch.zeros(num_timesteps)) + self.register_buffer('Lt_count', torch.zeros(num_timesteps)) + + # Gaussian part + def gaussian_q_mean_variance(self, x_start, t): + mean = ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_1_min_cumprod_alpha, t, x_start.shape + ) + return mean, variance, log_variance + + def gaussian_q_sample(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): + assert x_start.shape == x_t.shape + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + 
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def gaussian_p_mean_variance( + self, model_output, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None + ): + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + + model_variance = torch.cat([self.posterior_variance[1].unsqueeze(0).to(x.device), (1. - self.alphas)[1:]], dim=0) + # model_variance = self.posterior_variance.to(x.device) + model_log_variance = torch.log(model_variance) + + model_variance = extract(model_variance, t, x.shape) + model_log_variance = extract(model_log_variance, t, x.shape) + + + if self.gaussian_parametrization == 'eps': + pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + elif self.gaussian_parametrization == 'x0': + pred_xstart = model_output + else: + raise NotImplementedError + + model_mean, _, _ = self.gaussian_q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ), f'{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}' + + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _vb_terms_bpd( + self, model_output, x_start, x_t, t, clip_denoised=False, model_kwargs=None + ): + true_mean, _, true_log_variance_clipped = self.gaussian_q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.gaussian_p_mean_variance( + model_output, x_t, t, clip_denoised=clip_denoised, 
def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None):
    """Return the loss for the numerical (Gaussian) features.

    ``self.gaussian_loss_type`` selects the loss:

    * ``'mse'``: ``mean_flat`` of the squared difference between ``noise``
      and ``model_out``.
    * ``'kl'``: the ``"output"`` entry of ``self._vb_terms_bpd``.

    :param model_out: denoiser output for the numerical columns.
    :param x_start: clean numerical inputs.
    :param x_t: noised numerical inputs at step ``t``.
    :param t: timestep indices.
    :param noise: the noise that produced ``x_t``.
    :param model_kwargs: forwarded to ``_vb_terms_bpd`` (defaults to ``{}``).
    :raises ValueError: if ``gaussian_loss_type`` is neither ``'mse'`` nor ``'kl'``.
    """
    if model_kwargs is None:
        model_kwargs = {}

    if self.gaussian_loss_type == 'mse':
        return mean_flat((noise - model_out) ** 2)

    if self.gaussian_loss_type == 'kl':
        return self._vb_terms_bpd(
            model_output=model_out,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]

    # Previously an unknown type fell through to a bare KeyError: 'loss'.
    raise ValueError(
        f"unknown gaussian_loss_type: {self.gaussian_loss_type!r}"
    )
def multinomial_kl(self, log_prob1, log_prob2):
    """Return KL(p1 || p2) summed over dim 1, given log-probabilities."""
    log_ratio = log_prob1 - log_prob2
    weighted = torch.exp(log_prob1) * log_ratio
    return torch.sum(weighted, dim=1)
log_pred[:, ix] = F.log_softmax(model_out[:, ix], dim=1) + return log_pred + + def q_posterior(self, log_x_start, log_x_t, t): + # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0) + # where q(xt | xt-1, x0) = q(xt | xt-1). + + # EV_log_qxt_x0 = self.q_pred(log_x_start, t) + + # print('sum exp', EV_log_qxt_x0.exp().sum(1).mean()) + # assert False + + # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1) + t_minus_1 = t - 1 + # Remove negative values, will not be used anyway for final decoder + t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) + log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) + + num_axes = (1,) * (len(log_x_start.size()) - 1) + t_broadcast = t.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like(log_x_start) + log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0.to(torch.float32)) + + # unnormed_logprobs = log_EV_qxtmin_x0 + + # log q_pred_one_timestep(x_t, t) + # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! + # Not very easy to see why this is true. 
def p_pred(self, model_out, log_x, t, out_dict):
    """Return the model's log-probabilities for the previous step.

    With ``parametrization == 'x0'`` the output of ``predict_start`` is fed
    through ``q_posterior``; with ``'direct'`` it is returned as is.

    :raises ValueError: for any other parametrization.
    """
    mode = self.parametrization
    if mode not in ('x0', 'direct'):
        raise ValueError

    start_log_probs = self.predict_start(model_out, log_x, t=t, out_dict=out_dict)
    if mode == 'direct':
        return start_log_probs

    return self.q_posterior(
        log_x_start=start_log_probs, log_x_t=log_x, t=t)
def nll(self, log_x_start, out_dict):
    """Accumulate the per-timestep multinomial loss over all timesteps.

    For every ``t`` in ``0 .. num_timesteps - 1`` a noised sample is drawn
    with ``q_sample`` and the term returned by ``compute_Lt`` is added;
    ``kl_prior`` is added once at the end.

    :param log_x_start: log one-hot encoding of the clean categorical data.
    :param out_dict: conditioning dictionary forwarded to ``compute_Lt``.
    :return: the accumulated loss.
    """
    b = log_x_start.size(0)
    device = log_x_start.device
    loss = 0
    for t in range(0, self.num_timesteps):
        # Same timestep for every element of the batch.
        t_array = (torch.ones(b, device=device) * t).long()

        # NOTE(review): compute_Lt declares `model_out` as its first required
        # parameter and it is not supplied here, so this call raises TypeError
        # as written. Confirm how the denoiser output should be obtained
        # before relying on this method.
        kl = self.compute_Lt(
            log_x_start=log_x_start,
            log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array),
            t=t_array,
            out_dict=out_dict)

        loss += kl

    loss += self.kl_prior(log_x_start)

    return loss
log_x_start.size(0) + device = log_x_start.device + ones = torch.ones(b, device=device).long() + + log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) + log_half_prob = -torch.log(self.num_classes_expanded * torch.ones_like(log_qxT_prob)) + + kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) + return sum_except_batch(kl_prior) + + def compute_Lt(self, model_out, log_x_start, log_x_t, t, out_dict, detach_mean=False): + log_true_prob = self.q_posterior( + log_x_start=log_x_start, log_x_t=log_x_t, t=t) + log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t, out_dict=out_dict) + + if detach_mean: + log_model_prob = log_model_prob.detach() + + kl = self.multinomial_kl(log_true_prob, log_model_prob) + kl = sum_except_batch(kl) + + decoder_nll = -log_categorical(log_x_start, log_model_prob) + decoder_nll = sum_except_batch(decoder_nll) + + mask = (t == torch.zeros_like(t)).float() + loss = mask * decoder_nll + (1. - mask) * kl + + return loss + + def sample_time(self, b, device, method='uniform'): + if method == 'importance': + if not (self.Lt_count > 10).all(): + return self.sample_time(b, device, method='uniform') + + Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001 + Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1. 
def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt, out_dict):
    """Return the loss for the categorical (multinomial) features.

    ``self.multinomial_loss_type`` selects the estimator:

    * ``'vb_stochastic'``: ``compute_Lt`` at the sampled timestep, divided by
      the sampling probability ``pt``, plus ``kl_prior``.
    * ``'vb_all'``: the negated result of ``self.nll`` (marked deprecated and
      expensive in the original code).

    :param model_out: denoiser output for the categorical columns.
    :param log_x_start: log one-hot encoding of the clean categorical data.
    :param log_x_t: log one-hot encoding of the noised data at step ``t``.
    :param t: timestep indices.
    :param pt: probability with which each ``t`` was sampled.
    :param out_dict: conditioning dictionary forwarded to the loss terms.
    :raises ValueError: for any other loss type.
    """
    if self.multinomial_loss_type == 'vb_stochastic':
        kl = self.compute_Lt(
            model_out, log_x_start, log_x_t, t, out_dict
        )
        kl_prior = self.kl_prior(log_x_start)
        # Upweigh loss term of the kl
        return kl / pt + kl_prior

    if self.multinomial_loss_type == 'vb_all':
        # Expensive, dont do it ;).
        # DEPRECATED
        # nll() requires out_dict; the original call omitted it and would
        # have raised TypeError before computing anything.
        return -self.nll(log_x_start, out_dict)

    raise ValueError()
model_out_num = model_out[:, :self.num_numerical_features] + model_out_cat = model_out[:, self.num_numerical_features:] + + loss_multi = torch.zeros((1,)).float() + loss_gauss = torch.zeros((1,)).float() + if x_cat.shape[1] > 0: + loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, t, pt, out_dict) / len(self.num_classes) + + if x_num.shape[1] > 0: + loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise) + + # loss_multi = torch.where(out_dict['y'] == 1, loss_multi, 2 * loss_multi) + # loss_gauss = torch.where(out_dict['y'] == 1, loss_gauss, 2 * loss_gauss) + + return loss_multi.mean(), loss_gauss.mean() + + @torch.no_grad() + def mixed_elbo(self, x0, out_dict): + b = x0.size(0) + device = x0.device + + x_num = x0[:, :self.num_numerical_features] + x_cat = x0[:, self.num_numerical_features:] + has_cat = x_cat.shape[1] > 0 + if has_cat: + log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes).to(device) + + gaussian_loss = [] + xstart_mse = [] + mse = [] + mu_mse = [] + out_mean = [] + true_mean = [] + multinomial_loss = [] + for t in range(self.num_timesteps): + t_array = (torch.ones(b, device=device) * t).long() + noise = torch.randn_like(x_num) + + x_num_t = self.gaussian_q_sample(x_start=x_num, t=t_array, noise=noise) + if has_cat: + log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t_array) + else: + log_x_cat_t = x_cat + + model_out = self._denoise_fn( + torch.cat([x_num_t, log_x_cat_t], dim=1), + t_array, + **out_dict + ) + + model_out_num = model_out[:, :self.num_numerical_features] + model_out_cat = model_out[:, self.num_numerical_features:] + + kl = torch.tensor([0.0]) + if has_cat: + kl = self.compute_Lt( + model_out=model_out_cat, + log_x_start=log_x_cat, + log_x_t=log_x_cat_t, + t=t_array, + out_dict=out_dict + ) + + out = self._vb_terms_bpd( + model_out_num, + x_start=x_num, + x_t=x_num_t, + t=t_array, + clip_denoised=False + ) + + multinomial_loss.append(kl) + gaussian_loss.append(out["output"]) + 
xstart_mse.append(mean_flat((out["pred_xstart"] - x_num) ** 2)) + # mu_mse.append(mean_flat(out["mean_mse"])) + out_mean.append(mean_flat(out["out_mean"])) + true_mean.append(mean_flat(out["true_mean"])) + + eps = self._predict_eps_from_xstart(x_num_t, t_array, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + gaussian_loss = torch.stack(gaussian_loss, dim=1) + multinomial_loss = torch.stack(multinomial_loss, dim=1) + xstart_mse = torch.stack(xstart_mse, dim=1) + mse = torch.stack(mse, dim=1) + # mu_mse = torch.stack(mu_mse, dim=1) + out_mean = torch.stack(out_mean, dim=1) + true_mean = torch.stack(true_mean, dim=1) + + + prior_gauss = self._prior_gaussian(x_num) + + prior_multin = torch.tensor([0.0]) + if has_cat: + prior_multin = self.kl_prior(log_x_cat) + + total_gauss = gaussian_loss.sum(dim=1) + prior_gauss + total_multin = multinomial_loss.sum(dim=1) + prior_multin + return { + "total_gaussian": total_gauss, + "total_multinomial": total_multin, + "losses_gaussian": gaussian_loss, + "losses_multinimial": multinomial_loss, + "xstart_mse": xstart_mse, + "mse": mse, + # "mu_mse": mu_mse + "out_mean": out_mean, + "true_mean": true_mean + } + + @torch.no_grad() + def gaussian_ddim_step( + self, + model_out_num, + x, + t, + clip_denoised=False, + denoised_fn=None, + eta=0.0 + ): + out = self.gaussian_p_mean_variance( + model_out_num, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=None, + ) + + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = extract(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * torch.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + + noise = torch.randn_like(x) + mean_pred = ( + out["pred_xstart"] * torch.sqrt(alpha_bar_prev) + + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * 
@torch.no_grad()
def gaussian_ddim_sample(
    self,
    noise,
    T,
    out_dict,
    eta=0.0
):
    """Run ``T`` DDIM steps on the numerical features, starting from ``noise``.

    Steps run from ``T - 1`` down to ``0``. At each step the denoiser is
    called on the current sample and the result is passed to
    ``gaussian_ddim_step``.

    :param noise: starting tensor; its first dimension is the batch size.
    :param T: number of steps to run.
    :param out_dict: keyword arguments forwarded to the denoiser.
    :param eta: forwarded to ``gaussian_ddim_step`` (default ``0.0``).
    :return: the sample after the last step.
    """
    x = noise
    b = x.shape[0]
    device = x.device
    for t in reversed(range(T)):
        print(f'Sample timestep {t:4d}', end='\r')
        t_array = (torch.ones(b, device=device) * t).long()
        out_num = self._denoise_fn(x, t_array, **out_dict)
        # eta used to be accepted but never forwarded, so any non-zero value
        # requested by the caller was silently ignored.
        x = self.gaussian_ddim_step(
            out_num,
            x,
            t_array,
            eta=eta
        )
    print()
    return x
torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * torch.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + + coef1 = sigma + coef2 = alpha_bar_prev - sigma * alpha_bar + coef3 = 1 - coef1 - coef2 + + + log_ps = torch.stack([ + torch.log(coef1) + log_x_t, + torch.log(coef2) + log_x0, + torch.log(coef3) - torch.log(self.num_classes_expanded) + ], dim=2) + + log_prob = torch.logsumexp(log_ps, dim=2) + + out = self.log_sample_categorical(log_prob) + + return out + + @torch.no_grad() + def sample_ddim(self, num_samples, y_dist): + b = num_samples + device = self.log_alpha.device + z_norm = torch.randn((b, self.num_numerical_features), device=device) + + has_cat = self.num_classes[0] != 0 + log_z = torch.zeros((b, 0), device=device).float() + if has_cat: + uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device) + log_z = self.log_sample_categorical(uniform_logits) + + y = torch.multinomial( + y_dist, + num_samples=b, + replacement=True + ) + out_dict = {'y': y.long().to(device)} + for i in reversed(range(0, self.num_timesteps)): + print(f'Sample timestep {i:4d}', end='\r') + t = torch.full((b,), i, device=device, dtype=torch.long) + model_out = self._denoise_fn( + torch.cat([z_norm, log_z], dim=1).float(), + t, + **out_dict + ) + model_out_num = model_out[:, :self.num_numerical_features] + model_out_cat = model_out[:, self.num_numerical_features:] + z_norm = self.gaussian_ddim_step(model_out_num, z_norm, t, clip_denoised=False) + if has_cat: + log_z = self.multinomial_ddim_step(model_out_cat, log_z, t, out_dict) + + print() + z_ohe = torch.exp(log_z).round() + z_cat = log_z + if has_cat: + z_cat = ohe_to_categories(z_ohe, self.num_classes) + sample = torch.cat([z_norm, z_cat], dim=1).cpu() + return sample, out_dict + + + @torch.no_grad() + def sample(self, num_samples, y_dist): + b = num_samples + device = self.log_alpha.device + z_norm = torch.randn((b, self.num_numerical_features), device=device) + + has_cat = self.num_classes[0] != 0 + log_z = 
torch.zeros((b, 0), device=device).float() + if has_cat: + uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device) + log_z = self.log_sample_categorical(uniform_logits) + + y = torch.multinomial( + y_dist, + num_samples=b, + replacement=True + ) + out_dict = {'y': y.long().to(device)} + for i in reversed(range(0, self.num_timesteps)): + print(f'Sample timestep {i:4d}', end='\r') + t = torch.full((b,), i, device=device, dtype=torch.long) + model_out = self._denoise_fn( + torch.cat([z_norm, log_z], dim=1).float(), + t, + **out_dict + ) + model_out_num = model_out[:, :self.num_numerical_features] + model_out_cat = model_out[:, self.num_numerical_features:] + z_norm = self.gaussian_p_sample(model_out_num, z_norm, t, clip_denoised=False)['sample'] + if has_cat: + log_z = self.p_sample(model_out_cat, log_z, t, out_dict) + + print() + z_ohe = torch.exp(log_z).round() + z_cat = log_z + if has_cat: + z_cat = ohe_to_categories(z_ohe, self.num_classes) + sample = torch.cat([z_norm, z_cat], dim=1).cpu() + return sample, out_dict + + def sample_all(self, num_samples, batch_size, y_dist, ddim=False): + if ddim: + print('Sample using DDIM.') + sample_fn = self.sample_ddim + else: + sample_fn = self.sample + + b = batch_size + + all_y = [] + all_samples = [] + num_generated = 0 + while num_generated < num_samples: + sample, out_dict = sample_fn(b, y_dist) + mask_nan = torch.any(sample.isnan(), dim=1) + sample = sample[~mask_nan] + out_dict['y'] = out_dict['y'][~mask_nan] + + all_samples.append(sample) + all_y.append(out_dict['y'].cpu()) + if sample.shape[0] != b: + raise FoundNANsError + num_generated += sample.shape[0] + + x_gen = torch.cat(all_samples, dim=0)[:num_samples] + y_gen = torch.cat(all_y, dim=0)[:num_samples] + + return x_gen, y_gen \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py new file mode 100644 index 00000000..472ba5b5 --- 
/dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -0,0 +1,486 @@ +""" +Code was adapted from https://github.com/Yura52/rtdl +""" + +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim +from torch import Tensor + +ModuleType = Union[str, Callable[..., nn.Module]] + +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + +def _is_glu_activation(activation: ModuleType): + return ( + isinstance(activation, str) + and activation.endswith('GLU') + or activation in [ReGLU, GEGLU] + ) + + +def _all_or_none(values): + assert all(x is None for x in values) or all(x is not None for x in values) + +def reglu(x: Tensor) -> Tensor: + """The ReGLU activation function from [1]. + References: + [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + """ + assert x.shape[-1] % 2 == 0 + a, b = x.chunk(2, dim=-1) + return a * F.relu(b) + + +def geglu(x: Tensor) -> Tensor: + """The GEGLU activation function from [1]. 
+ References: + [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + """ + assert x.shape[-1] % 2 == 0 + a, b = x.chunk(2, dim=-1) + return a * F.gelu(b) + +class ReGLU(nn.Module): + """The ReGLU activation function from [shazeer2020glu]. + + Examples: + .. testcode:: + + module = ReGLU() + x = torch.randn(3, 4) + assert module(x).shape == (3, 2) + + References: + * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + """ + + def forward(self, x: Tensor) -> Tensor: + return reglu(x) + + +class GEGLU(nn.Module): + """The GEGLU activation function from [shazeer2020glu]. + + Examples: + .. testcode:: + + module = GEGLU() + x = torch.randn(3, 4) + assert module(x).shape == (3, 2) + + References: + * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 + """ + + def forward(self, x: Tensor) -> Tensor: + return geglu(x) + +def _make_nn_module(module_type: ModuleType, *args) -> nn.Module: + return ( + ( + ReGLU() + if module_type == 'ReGLU' + else GEGLU() + if module_type == 'GEGLU' + else getattr(nn, module_type)(*args) + ) + if isinstance(module_type, str) + else module_type(*args) + ) + + +class MLP(nn.Module): + """The MLP model used in [gorishniy2021revisiting]. + + The following scheme describes the architecture: + + .. code-block:: text + + MLP: (in) -> Block -> ... -> Block -> Linear -> (out) + Block: (in) -> Linear -> Activation -> Dropout -> (out) + + Examples: + .. 
testcode:: + + x = torch.randn(4, 2) + module = MLP.make_baseline(x.shape[1], [3, 5], 0.1, 1) + assert module(x).shape == (len(x), 1) + + References: + * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 + """ + + class Block(nn.Module): + """The main building block of `MLP`.""" + + def __init__( + self, + *, + d_in: int, + d_out: int, + bias: bool, + activation: ModuleType, + dropout: float, + ) -> None: + super().__init__() + self.linear = nn.Linear(d_in, d_out, bias) + self.activation = _make_nn_module(activation) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor) -> Tensor: + return self.dropout(self.activation(self.linear(x))) + + def __init__( + self, + *, + d_in: int, + d_layers: List[int], + dropouts: Union[float, List[float]], + activation: Union[str, Callable[[], nn.Module]], + d_out: int, + ) -> None: + """ + Note: + `make_baseline` is the recommended constructor. + """ + super().__init__() + if isinstance(dropouts, float): + dropouts = [dropouts] * len(d_layers) + assert len(d_layers) == len(dropouts) + assert activation not in ['ReGLU', 'GEGLU'] + + self.blocks = nn.ModuleList( + [ + MLP.Block( + d_in=d_layers[i - 1] if i else d_in, + d_out=d, + bias=True, + activation=activation, + dropout=dropout, + ) + for i, (d, dropout) in enumerate(zip(d_layers, dropouts)) + ] + ) + self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out) + + @classmethod + def make_baseline( + cls: Type['MLP'], + d_in: int, + d_layers: List[int], + dropout: float, + d_out: int, + ) -> 'MLP': + """Create a "baseline" `MLP`. + + This variation of MLP was used in [gorishniy2021revisiting]. 
Features: + + * :code:`Activation` = :code:`ReLU` + * all linear layers except for the first one and the last one are of the same dimension + * the dropout rate is the same for all dropout layers + + Args: + d_in: the input size + d_layers: the dimensions of the linear layers. If there are more than two + layers, then all of them except for the first and the last ones must + have the same dimension. Valid examples: :code:`[]`, :code:`[8]`, + :code:`[8, 16]`, :code:`[2, 2, 2, 2]`, :code:`[1, 2, 2, 4]`. Invalid + example: :code:`[1, 2, 3, 4]`. + dropout: the dropout rate for all hidden layers + d_out: the output size + Returns: + MLP + + References: + * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 + """ + assert isinstance(dropout, float) + if len(d_layers) > 2: + assert len(set(d_layers[1:-1])) == 1, ( + 'if d_layers contains more than two elements, then' + ' all elements except for the first and the last ones must be equal.' + ) + return MLP( + d_in=d_in, + d_layers=d_layers, # type: ignore + dropouts=dropout, + activation='ReLU', + d_out=d_out, + ) + + def forward(self, x: Tensor) -> Tensor: + x = x.float() + for block in self.blocks: + x = block(x) + x = self.head(x) + return x + + +class ResNet(nn.Module): + """The ResNet model used in [gorishniy2021revisiting]. + The following scheme describes the architecture: + .. code-block:: text + ResNet: (in) -> Linear -> Block -> ... -> Block -> Head -> (out) + |-> Norm -> Linear -> Activation -> Dropout -> Linear -> Dropout ->| + | | + Block: (in) ------------------------------------------------------------> Add -> (out) + Head: (in) -> Norm -> Activation -> Linear -> (out) + Examples: + .. 
testcode:: + x = torch.randn(4, 2) + module = ResNet.make_baseline( + d_in=x.shape[1], + n_blocks=2, + d_main=3, + d_hidden=4, + dropout_first=0.25, + dropout_second=0.0, + d_out=1 + ) + assert module(x).shape == (len(x), 1) + References: + * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 + """ + + class Block(nn.Module): + """The main building block of `ResNet`.""" + + def __init__( + self, + *, + d_main: int, + d_hidden: int, + bias_first: bool, + bias_second: bool, + dropout_first: float, + dropout_second: float, + normalization: ModuleType, + activation: ModuleType, + skip_connection: bool, + ) -> None: + super().__init__() + self.normalization = _make_nn_module(normalization, d_main) + self.linear_first = nn.Linear(d_main, d_hidden, bias_first) + self.activation = _make_nn_module(activation) + self.dropout_first = nn.Dropout(dropout_first) + self.linear_second = nn.Linear(d_hidden, d_main, bias_second) + self.dropout_second = nn.Dropout(dropout_second) + self.skip_connection = skip_connection + + def forward(self, x: Tensor) -> Tensor: + x_input = x + x = self.normalization(x) + x = self.linear_first(x) + x = self.activation(x) + x = self.dropout_first(x) + x = self.linear_second(x) + x = self.dropout_second(x) + if self.skip_connection: + x = x_input + x + return x + + class Head(nn.Module): + """The final module of `ResNet`.""" + + def __init__( + self, + *, + d_in: int, + d_out: int, + bias: bool, + normalization: ModuleType, + activation: ModuleType, + ) -> None: + super().__init__() + self.normalization = _make_nn_module(normalization, d_in) + self.activation = _make_nn_module(activation) + self.linear = nn.Linear(d_in, d_out, bias) + + def forward(self, x: Tensor) -> Tensor: + if self.normalization is not None: + x = self.normalization(x) + x = self.activation(x) + x = self.linear(x) + return x + + def __init__( + self, + *, + d_in: int, + n_blocks: int, 
+ d_main: int, + d_hidden: int, + dropout_first: float, + dropout_second: float, + normalization: ModuleType, + activation: ModuleType, + d_out: int, + ) -> None: + """ + Note: + `make_baseline` is the recommended constructor. + """ + super().__init__() + + self.first_layer = nn.Linear(d_in, d_main) + if d_main is None: + d_main = d_in + self.blocks = nn.Sequential( + *[ + ResNet.Block( + d_main=d_main, + d_hidden=d_hidden, + bias_first=True, + bias_second=True, + dropout_first=dropout_first, + dropout_second=dropout_second, + normalization=normalization, + activation=activation, + skip_connection=True, + ) + for _ in range(n_blocks) + ] + ) + self.head = ResNet.Head( + d_in=d_main, + d_out=d_out, + bias=True, + normalization=normalization, + activation=activation, + ) + + @classmethod + def make_baseline( + cls: Type['ResNet'], + *, + d_in: int, + n_blocks: int, + d_main: int, + d_hidden: int, + dropout_first: float, + dropout_second: float, + d_out: int, + ) -> 'ResNet': + """Create a "baseline" `ResNet`. + This variation of ResNet was used in [gorishniy2021revisiting]. Features: + * :code:`Activation` = :code:`ReLU` + * :code:`Norm` = :code:`BatchNorm1d` + Args: + d_in: the input size + n_blocks: the number of Blocks + d_main: the input size (or, equivalently, the output size) of each Block + d_hidden: the output size of the first linear layer in each Block + dropout_first: the dropout rate of the first dropout layer in each Block. + dropout_second: the dropout rate of the second dropout layer in each Block. 
+ References: + * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 + """ + return cls( + d_in=d_in, + n_blocks=n_blocks, + d_main=d_main, + d_hidden=d_hidden, + dropout_first=dropout_first, + dropout_second=dropout_second, + normalization='BatchNorm1d', + activation='ReLU', + d_out=d_out, + ) + + def forward(self, x: Tensor) -> Tensor: + x = x.float() + x = self.first_layer(x) + x = self.blocks(x) + x = self.head(x) + return x +#### For diffusion + +class MLPDiffusion(nn.Module): + def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 128): + super().__init__() + self.dim_t = dim_t + self.num_classes = num_classes + self.is_y_cond = is_y_cond + + # d0 = rtdl_params['d_layers'][0] + + rtdl_params['d_in'] = dim_t + rtdl_params['d_out'] = d_in + + self.mlp = MLP.make_baseline(**rtdl_params) + + if self.num_classes > 0 and is_y_cond: + self.label_emb = nn.Embedding(self.num_classes, dim_t) + elif self.num_classes == 0 and is_y_cond: + self.label_emb = nn.Linear(1, dim_t) + + self.proj = nn.Linear(d_in, dim_t) + self.time_embed = nn.Sequential( + nn.Linear(dim_t, dim_t), + nn.SiLU(), + nn.Linear(dim_t, dim_t) + ) + + def forward(self, x, timesteps, y=None): + emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) + if self.is_y_cond and y is not None: + if self.num_classes > 0: + y = y.squeeze() + else: + y = y.resize(y.size(0), 1).float() + emb += F.silu(self.label_emb(y)) + x = self.proj(x) + emb + return self.mlp(x) + +class ResNetDiffusion(nn.Module): + def __init__(self, d_in, num_classes, rtdl_params, dim_t = 256): + super().__init__() + self.dim_t = dim_t + self.num_classes = num_classes + + rtdl_params['d_in'] = d_in + rtdl_params['d_out'] = d_in + rtdl_params['emb_d'] = dim_t + self.resnet = ResNet.make_baseline(**rtdl_params) + + if self.num_classes > 0: + self.label_emb = nn.Embedding(self.num_classes, dim_t) + + self.time_embed = 
nn.Sequential( + nn.Linear(dim_t, dim_t), + nn.SiLU(), + nn.Linear(dim_t, dim_t) + ) + + def forward(self, x, timesteps, y=None): + emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) + if y is not None and self.num_classes > 0: + emb += self.label_emb(y.squeeze()) + return self.resnet(x, emb) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt b/src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt new file mode 100644 index 00000000..acc088c4 --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt @@ -0,0 +1,15 @@ +category-encoders==2.3.0 +dython==0.5.1 +icecream==2.1.2 +libzero==0.0.8 +numpy==1.21.4 +optuna==2.10.1 +pandas==1.3.4 +pyarrow==6.0.0 +rtdl==0.0.9 +scikit-learn==1.0.2 +scipy==1.7.2 +skorch==0.11.0 +tomli-w==0.4.0 +tomli==1.2.2 +tqdm==4.62.3 diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py new file mode 100644 index 00000000..6376bfbf --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -0,0 +1,174 @@ +import torch +import numpy as np +import torch.nn.functional as F +from torch.profiler import record_function +from inspect import isfunction + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + +def sum_except_batch(x, num_dims=1): + ''' + Sums all dimensions except the first. + + Args: + x: Tensor, shape (batch_size, ...) + num_dims: int, number of batch dims (default=1) + + Returns: + x_sum: Tensor, shape (batch_size,) + ''' + return x.reshape(*x.shape[:num_dims], -1).sum(-1) + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def ohe_to_categories(ohe, K): + K = torch.from_numpy(K) + indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist() + res = [] + for i in range(len(indices) - 1): + res.append(ohe[:, indices[i]:indices[i+1]].argmax(dim=1)) + return torch.stack(res, dim=1) + +def log_1_min_a(a): + return torch.log(1 - a.exp() + 1e-40) + + +def log_add_exp(a, b): + maximum = torch.max(a, b) + return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) + +def exists(x): + return x is not None + +def extract(a, t, x_shape): + b, *_ = t.shape + t = t.to(a.device) + out = a.gather(-1, t) + while len(out.shape) < len(x_shape): + out = out[..., None] + return out.expand(x_shape) + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def log_categorical(log_x_start, log_prob): + return (log_x_start.exp() * log_prob).sum(dim=1) + +def index_to_log_onehot(x, num_classes): + onehots = [] + for i in range(len(num_classes)): + onehots.append(F.one_hot(x[:, i], num_classes[i])) + + x_onehot = torch.cat(onehots, dim=1) + log_onehot = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_onehot + +def log_sum_exp_by_classes(x, slices): + device = x.device + res = torch.zeros_like(x) + for ixs in slices: + res[:, ixs] = torch.logsumexp(x[:, ixs], dim=1, keepdim=True) + + assert x.size() == res.size() + + return res + +@torch.jit.script +def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + m = torch.maximum(a, b) + return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m + +@torch.jit.script +def sliced_logsumexp(x, slices): + lse = torch.logcumsumexp( + torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float('inf')), + dim=-1) + + slice_starts = slices[:-1] + slice_ends = slices[1:] + + slice_lse = log_sub_exp(lse[:, slice_ends], lse[:, slice_starts]) + slice_lse_repeated = torch.repeat_interleave( + slice_lse, + slice_ends - 
slice_starts, + dim=-1 + ) + return slice_lse_repeated + +def log_onehot_to_index(log_x): + return log_x.argmax(1) + +class FoundNANsError(BaseException): + """Found NANs during sampling""" + def __init__(self, message='Found NANs during sampling.'): + super(FoundNANsError, self).__init__(message) \ No newline at end of file diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py new file mode 100644 index 00000000..999f7312 --- /dev/null +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -0,0 +1,217 @@ +""" +Reference: Kotelnikov, Akim et al. “TabDDPM: Modelling Tabular Data with Diffusion Models.” ArXiv abs/2209.15421 (2022): n. pag. +""" + +# stdlib +from pathlib import Path +from copy import deepcopy +from typing import Any, List, Optional, Union + +# third party +import numpy as np +import pandas as pd + +# Necessary packages +from pydantic import validate_arguments +import torch +from torch.utils.data import sampler + +# synthcity absolute +from synthcity.metrics.weighted_metrics import WeightedMetrics +from synthcity.plugins.core.dataloader import DataLoader +from synthcity.plugins.core.distribution import ( + CategoricalDistribution, + Distribution, + FloatDistribution, + IntegerDistribution, +) +from synthcity.plugins.core.models.tabular_ddpm import GaussianMultinomialDiffusion, MLPDiffusion, ResNetDiffusion +from synthcity.plugins.core.plugin import Plugin +from synthcity.plugins.core.schema import Schema +from synthcity.utils.constants import DEVICE + + +class DDPMPlugin(Plugin): + """ + .. inheritance-diagram:: synthcity.plugins.generic.plugin_tab_ddpm.TabDDPMPlugin + :parts: 1 + + + Tabular denoising diffusion probabilistic model. + + Args: + ... 
+ + Example: + >>> from sklearn.datasets import load_iris + >>> from synthcity.plugins import Plugins + >>> X, y = load_iris(as_frame = True, return_X_y = True) + >>> X["target"] = y + >>> plugin = Plugins().get("ddpm", n_iter = 100) + >>> plugin.fit(X) + >>> plugin.generate(50) + + """ + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def __init__( + self, + n_iter = 1000, + lr = 0.002, + weight_decay = 1e-4, + batch_size = 1024, + model_type = 'mlp', + model_params = None, + num_timesteps = 1000, + gaussian_loss_type = 'mse', + scheduler = 'cosine', + change_val = False, + device: Any = DEVICE, + # early stopping + n_iter_min: int = 100, + n_iter_print: int = 50, + patience: int = 5, + patience_metric: Optional[WeightedMetrics] = None, + # core plugin arguments + random_state: int = 0, + workspace: Path = Path("workspace"), + compress_dataset: bool = False, + sampling_patience: int = 500, + **kwargs: Any + ) -> None: + super().__init__( + device=device, + random_state=random_state, + sampling_patience=sampling_patience, + workspace=workspace, + compress_dataset=compress_dataset, + **kwargs + ) + + if patience_metric is None: + patience_metric = WeightedMetrics( + metrics=[("detection", "detection_mlp")], + weights=[1], + workspace=workspace, + ) + + self.__dict__.update(locals()) + del self.self, self.kwargs + + @staticmethod + def name() -> str: + return "ddpm" + + @staticmethod + def type() -> str: + return "generic" + + @staticmethod + def hyperparameter_space(**kwargs: Any) -> List[Distribution]: + raise NotImplementedError + + def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "DDPMPlugin": + + if self.model_type == 'mlp': + self.model = MLPDiffusion(**self.model_params) + elif self.model_type == 'resnet': + self.model = ResNetDiffusion(**self.model_params) + else: + raise "Unknown model!" 
+ + self.diffusion = GaussianMultinomialDiffusion( + num_classes=num_classes, + num_numerical_features=num_numerical_features, + denoise_fn=self.model, + gaussian_loss_type=self.gaussian_loss_type, + num_timesteps=self.num_timesteps, + scheduler=self.scheduler, + device=self.device + ).to(self.device).train() + + trainer = Trainer( + self.model, + X, + lr=self.lr, + weight_decay=self.weight_decay, + steps=self.n_iter, + device=self.device + ) + + trainer.run_loop() + return self + + def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader: + self.diffusion.eval() + return self._safe_generate(self.model.sample_all, count, syn_schema) + + + +class Trainer: + def __init__(self, diffusion, train_iter, lr, weight_decay, steps, device=DEVICE): + self.diffusion = diffusion + self.ema_model = deepcopy(self.diffusion._denoise_fn) + for param in self.ema_model.parameters(): + param.detach_() + + self.train_iter = train_iter + self.steps = steps + self.init_lr = lr + self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay) + self.device = device + self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) + self.log_every = 100 + self.print_every = 500 + self.ema_every = 1000 + + def _anneal_lr(self, step): + frac_done = step / self.steps + lr = self.init_lr * (1 - frac_done) + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def _run_step(self, x, out_dict): + x = x.to(self.device) + for k in out_dict: + out_dict[k] = out_dict[k].long().to(self.device) + self.optimizer.zero_grad() + loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) + loss = loss_multi + loss_gauss + loss.backward() + self.optimizer.step() + + return loss_multi, loss_gauss + + def run_loop(self): + step = 0 + curr_loss_multi = 0.0 + curr_loss_gauss = 0.0 + + curr_count = 0 + while step < self.steps: + x, out_dict = next(self.train_iter) + out_dict = {'y': out_dict} + batch_loss_multi, 
batch_loss_gauss = self._run_step(x, out_dict) + + self._anneal_lr(step) + + curr_count += len(x) + curr_loss_multi += batch_loss_multi.item() * len(x) + curr_loss_gauss += batch_loss_gauss.item() * len(x) + + if (step + 1) % self.log_every == 0: + mloss = np.around(curr_loss_multi / curr_count, 4) + gloss = np.around(curr_loss_gauss / curr_count, 4) + if (step + 1) % self.print_every == 0: + print(f'Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') + self.loss_history.loc[len(self.loss_history)] =[step + 1, mloss, gloss, mloss + gloss] + curr_count = 0 + curr_loss_gauss = 0.0 + curr_loss_multi = 0.0 + + update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters()) + + step += 1 + + +plugin = DDPMPlugin diff --git a/third-party/tab-ddpm b/third-party/tab-ddpm new file mode 160000 index 00000000..41f2415a --- /dev/null +++ b/third-party/tab-ddpm @@ -0,0 +1 @@ +Subproject commit 41f2415a378f1e8e8f4f5c3b8736521c0d47cf22 From fed898b86dd68a47bbf105877de69217add0d0a0 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 3 Mar 2023 13:55:30 +0100 Subject: [PATCH 02/95] Add DDPM test script and update DDPM plugin --- .../core/models/tabular_ddpm/__init__.py | 36 ++++- .../gaussian_multinomial_diffsuion.py | 31 ++--- .../plugins/core/models/tabular_ddpm/utils.py | 2 +- src/synthcity/plugins/generic/plugin_ddpm.py | 127 +++++++++--------- tests/plugins/generic/test_ddpm.py | 125 +++++++++++++++++ 5 files changed, 234 insertions(+), 87 deletions(-) create mode 100644 tests/plugins/generic/test_ddpm.py diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 80d346c2..6dfe0bb3 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,2 +1,36 @@ from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion # noqa -from .modules import 
MLPDiffusion, ResNetDiffusion # noqa \ No newline at end of file +from .modules import MLPDiffusion, ResNetDiffusion # noqa + +# stdlib +from copy import deepcopy +from typing import Any, Optional, Union + +# third party +import numpy as np +import pandas as pd +import torch +from pydantic import validate_arguments +from sklearn.preprocessing import OneHotEncoder +from torch import nn + +# synthcity absolute +from synthcity.utils.constants import DEVICE +from synthcity.utils.samplers import BaseSampler, ConditionalDatasetSampler + +# synthcity relative +from ..tabular_encoder import TabularEncoder + + +# class TabDDPM(nn.Module): +# def __init__( +# self, +# X: pd.DataFrame, + +# def generate(self, n_samples: int) -> pd.DataFrame: +# self.eval() +# with torch.no_grad(): +# samples = self.diffusion.sample(n_samples) +# return samples + +# def forward(self, count: int) -> pd.DataFrame: +# pass \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 0d0f2ce4..7e93c070 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -6,6 +6,7 @@ import torch.nn.functional as F import torch import math +import pandas as pd import numpy as np from .utils import * @@ -13,7 +14,6 @@ """ Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ -eps = 1e-8 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): """ @@ -59,12 +59,13 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return np.array(betas) + class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( 
self, - num_classes: np.array, - num_numerical_features: int, denoise_fn, + num_numerical_features, + num_classes=None, num_timesteps=1000, gaussian_loss_type='mse', gaussian_parametrization='eps', @@ -83,7 +84,7 @@ def __init__( ' This is expensive both in terms of memory and computation.') self.num_numerical_features = num_numerical_features - self.num_classes = num_classes # it as a vector [K1, K2, ..., Km] + self.num_classes = num_classes or [0] self.num_classes_expanded = torch.from_numpy( np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))]) ).to(device) @@ -213,7 +214,6 @@ def gaussian_p_mean_variance( model_variance = extract(model_variance, t, x.shape) model_log_variance = extract(model_log_variance, t, x.shape) - if self.gaussian_parametrization == 'eps': pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) elif self.gaussian_parametrization == 'x0': @@ -299,7 +299,7 @@ def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None): return terms['loss'] - def _predict_xstart_from_eps(self, x_t, t, eps): + def _predict_xstart_from_eps(self, x_t, t, eps=1e-8): assert x_t.shape == eps.shape return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t @@ -702,7 +702,6 @@ def mixed_elbo(self, x0, out_dict): out_mean = torch.stack(out_mean, dim=1) true_mean = torch.stack(true_mean, dim=1) - prior_gauss = self._prior_gaussian(x_num) prior_multin = torch.tensor([0.0]) @@ -920,7 +919,6 @@ def sample_ddim(self, num_samples, y_dist): z_cat = ohe_to_categories(z_ohe, self.num_classes) sample = torch.cat([z_norm, z_cat], dim=1).cpu() return sample, out_dict - @torch.no_grad() def sample(self, num_samples, y_dist): @@ -962,31 +960,28 @@ def sample(self, num_samples, y_dist): sample = torch.cat([z_norm, z_cat], dim=1).cpu() return sample, out_dict - def sample_all(self, num_samples, batch_size, y_dist, ddim=False): + def sample_all(self, num_samples, y_dist, max_batch_size=2000, ddim=False): 
if ddim: print('Sample using DDIM.') sample_fn = self.sample_ddim else: sample_fn = self.sample - - b = batch_size + bs = np.diff(list(range(0, num_samples, max_batch_size)) + [num_samples]) all_y = [] all_samples = [] - num_generated = 0 - while num_generated < num_samples: + + for b in bs: sample, out_dict = sample_fn(b, y_dist) mask_nan = torch.any(sample.isnan(), dim=1) sample = sample[~mask_nan] + if sample.shape[0] != b: + raise FoundNANsError out_dict['y'] = out_dict['y'][~mask_nan] - all_samples.append(sample) all_y.append(out_dict['y'].cpu()) - if sample.shape[0] != b: - raise FoundNANsError - num_generated += sample.shape[0] x_gen = torch.cat(all_samples, dim=0)[:num_samples] y_gen = torch.cat(all_y, dim=0)[:num_samples] - return x_gen, y_gen \ No newline at end of file + return x_gen, y_gen diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 6376bfbf..95abd42a 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -171,4 +171,4 @@ def log_onehot_to_index(log_x): class FoundNANsError(BaseException): """Found NANs during sampling""" def __init__(self, message='Found NANs during sampling.'): - super(FoundNANsError, self).__init__(message) \ No newline at end of file + super(FoundNANsError, self).__init__(message) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 999f7312..5565eb92 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -26,12 +26,13 @@ IntegerDistribution, ) from synthcity.plugins.core.models.tabular_ddpm import GaussianMultinomialDiffusion, MLPDiffusion, ResNetDiffusion +from synthcity.plugins.core.models.tabular_encoder import TabularEncoder from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema from 
synthcity.utils.constants import DEVICE -class DDPMPlugin(Plugin): +class TabDDPMPlugin(Plugin): """ .. inheritance-diagram:: synthcity.plugins.generic.plugin_tab_ddpm.TabDDPMPlugin :parts: 1 @@ -67,6 +68,8 @@ def __init__( scheduler = 'cosine', change_val = False, device: Any = DEVICE, + log_interval: int = 100, + print_interval: int = 500, # early stopping n_iter_min: int = 100, n_iter_print: int = 50, @@ -87,7 +90,7 @@ def __init__( compress_dataset=compress_dataset, **kwargs ) - + if patience_metric is None: patience_metric = WeightedMetrics( metrics=[("detection", "detection_mlp")], @@ -110,88 +113,72 @@ def type() -> str: def hyperparameter_space(**kwargs: Any) -> List[Distribution]: raise NotImplementedError - def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "DDPMPlugin": + def _anneal_lr(self, step): + frac_done = step / self.steps + lr = self.lr * (1 - frac_done) + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def _one_step(self, x, out_dict): + x = x.to(self.device) + for k in out_dict: + out_dict[k] = out_dict[k].long().to(self.device) + self.optimizer.zero_grad() + loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) + loss = loss_multi + loss_gauss + loss.backward() + self.optimizer.step() + return loss_multi, loss_gauss + def _update_ema(self, target_params, source_params, rate=0.999): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). 
+ """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) + + def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": + # TODO: add parameters of TabularEncoder + self.encoder = TabularEncoder().fit(X) + if self.model_type == 'mlp': self.model = MLPDiffusion(**self.model_params) elif self.model_type == 'resnet': self.model = ResNetDiffusion(**self.model_params) else: raise "Unknown model!" - + self.diffusion = GaussianMultinomialDiffusion( - num_classes=num_classes, - num_numerical_features=num_numerical_features, denoise_fn=self.model, + num_numerical_features=self.encoder.n_features(), gaussian_loss_type=self.gaussian_loss_type, num_timesteps=self.num_timesteps, scheduler=self.scheduler, device=self.device - ).to(self.device).train() + ).to(self.device) - trainer = Trainer( - self.model, - X, - lr=self.lr, - weight_decay=self.weight_decay, - steps=self.n_iter, - device=self.device - ) - - trainer.run_loop() - return self - - def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader: - self.diffusion.eval() - return self._safe_generate(self.model.sample_all, count, syn_schema) - - - -class Trainer: - def __init__(self, diffusion, train_iter, lr, weight_decay, steps, device=DEVICE): - self.diffusion = diffusion - self.ema_model = deepcopy(self.diffusion._denoise_fn) + self.ema_model = deepcopy(self.model) for param in self.ema_model.parameters(): param.detach_() - self.train_iter = train_iter - self.steps = steps - self.init_lr = lr - self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay) - self.device = device + self.optimizer = torch.optim.AdamW( + self.diffusion.parameters(), lr=self.lr, weight_decay=self.weight_decay) + + # TODO: check data type of X self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) - self.log_every = 100 - self.print_every = 500 - self.ema_every = 1000 - - def 
_anneal_lr(self, step): - frac_done = step / self.steps - lr = self.init_lr * (1 - frac_done) - for param_group in self.optimizer.param_groups: - param_group["lr"] = lr - - def _run_step(self, x, out_dict): - x = x.to(self.device) - for k in out_dict: - out_dict[k] = out_dict[k].long().to(self.device) - self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) - loss = loss_multi + loss_gauss - loss.backward() - self.optimizer.step() - - return loss_multi, loss_gauss - - def run_loop(self): - step = 0 + curr_loss_multi = 0.0 curr_loss_gauss = 0.0 curr_count = 0 - while step < self.steps: - x, out_dict = next(self.train_iter) + for step in range(self.n_iter): + x, out_dict = next(X) out_dict = {'y': out_dict} - batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict) + batch_loss_multi, batch_loss_gauss = self._one_step(x, out_dict) self._anneal_lr(step) @@ -199,19 +186,25 @@ def run_loop(self): curr_loss_multi += batch_loss_multi.item() * len(x) curr_loss_gauss += batch_loss_gauss.item() * len(x) - if (step + 1) % self.log_every == 0: + if (step + 1) % self.log_interval == 0: mloss = np.around(curr_loss_multi / curr_count, 4) gloss = np.around(curr_loss_gauss / curr_count, 4) - if (step + 1) % self.print_every == 0: + if (step + 1) % self.print_interval == 0: print(f'Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') - self.loss_history.loc[len(self.loss_history)] =[step + 1, mloss, gloss, mloss + gloss] + self.loss_history.loc[len(self.loss_history)] = [ + step + 1, mloss, gloss, mloss + gloss] curr_count = 0 curr_loss_gauss = 0.0 curr_loss_multi = 0.0 - update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters()) + self._update_ema(self.ema_model.parameters(), self.model.parameters()) + + return self - step += 1 + def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader: + self.diffusion.eval() + # TODO: check self.model.sample_all + return 
self._safe_generate(self.diffusion.sample_all, count, syn_schema) -plugin = DDPMPlugin +plugin = TabDDPMPlugin diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py new file mode 100644 index 00000000..ec112f48 --- /dev/null +++ b/tests/plugins/generic/test_ddpm.py @@ -0,0 +1,125 @@ +# third party +import numpy as np +import pandas as pd +import pytest +from generic_helpers import generate_fixtures +from sklearn.datasets import load_iris + +# synthcity absolute +from synthcity.metrics.eval import PerformanceEvaluatorXGB +from synthcity.plugins import Plugin +from synthcity.plugins.core.constraints import Constraints +from synthcity.plugins.core.dataloader import GenericDataLoader +from synthcity.plugins.generic.plugin_ddpm import plugin + +plugin_name = "ddpm" +plugin_args = {"n_iter": 100} + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_sanity(test_plugin: Plugin) -> None: + assert test_plugin is not None + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_name(test_plugin: Plugin) -> None: + assert test_plugin.name() == plugin_name + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_type(test_plugin: Plugin) -> None: + assert test_plugin.type() == "generic" + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_hyperparams(test_plugin: Plugin) -> None: + assert len(test_plugin.hyperparameter_space()) == 9 + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_fit(test_plugin: Plugin) -> None: + X = pd.DataFrame(load_iris()["data"]) + test_plugin.fit(GenericDataLoader(X)) + + +def test_plugin_generate() -> None: + test_plugin = plugin(n_layers_hidden=2, n_units_hidden=100, n_iter=50) + X = 
pd.DataFrame(load_iris()["data"]) + test_plugin.fit(GenericDataLoader(X)) + + X_gen = test_plugin.generate() + assert len(X_gen) == len(X) + assert test_plugin.schema_includes(X_gen) + + X_gen = test_plugin.generate(50) + assert len(X_gen) == 50 + assert test_plugin.schema_includes(X_gen) + + +def test_plugin_generate_constraints() -> None: + test_plugin = plugin(n_layers_hidden=2, n_units_hidden=100, n_iter=50) + X = pd.DataFrame(load_iris()["data"]) + test_plugin.fit(GenericDataLoader(X)) + + constraints = Constraints( + rules=[ + ("0", "le", 6), + ("0", "ge", 4.3), + ("1", "le", 4.4), + ("1", "ge", 3), + ("2", "le", 5.5), + ("2", "ge", 1.0), + ("3", "le", 2), + ("3", "ge", 0.1), + ] + ) + + X_gen = test_plugin.generate(constraints=constraints).dataframe() + assert len(X_gen) == len(X) + assert test_plugin.schema_includes(X_gen) + assert constraints.filter(X_gen).sum() == len(X_gen) + + X_gen = test_plugin.generate(count=50, constraints=constraints).dataframe() + assert len(X_gen) == 50 + assert test_plugin.schema_includes(X_gen) + assert constraints.filter(X_gen).sum() == len(X_gen) + assert list(X_gen.columns) == list(X.columns) + + +def test_sample_hyperparams() -> None: + for i in range(100): + args = plugin.sample_hyperparameters() + + assert plugin(**args) is not None + + +@pytest.mark.slow +@pytest.mark.parametrize("compress_dataset", [True, False]) +def test_eval_performance_nflow(compress_dataset: bool) -> None: + results = [] + + Xraw, y = load_iris(return_X_y=True, as_frame=True) + Xraw["target"] = y + X = GenericDataLoader(Xraw) + + for retry in range(2): + test_plugin = plugin(n_iter=5000, compress_dataset=compress_dataset) + evaluator = PerformanceEvaluatorXGB() + + test_plugin.fit(X) + X_syn = test_plugin.generate() + + results.append(evaluator.evaluate(X, X_syn)["syn_id"]) + + print(plugin.name(), results) + assert np.mean(results) > 0.8 From 34979cf240acbaa75a4d72998f18a5ff8b2a9100 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: 
Sun, 5 Mar 2023 21:15:17 +0100 Subject: [PATCH 03/95] add TabDDPM class and refactor --- .../core/models/tabular_ddpm/__init__.py | 143 ++++++++++++++--- .../gaussian_multinomial_diffsuion.py | 57 +++++-- .../core/models/tabular_ddpm/modules.py | 21 ++- .../plugins/core/models/tabular_ddpm/utils.py | 38 +++++ src/synthcity/plugins/generic/plugin_ddpm.py | 149 ++++++------------ 5 files changed, 263 insertions(+), 145 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 6dfe0bb3..46977c79 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,6 +1,3 @@ -from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion # noqa -from .modules import MLPDiffusion, ResNetDiffusion # noqa - # stdlib from copy import deepcopy from typing import Any, Optional, Union @@ -9,28 +6,136 @@ import numpy as np import pandas as pd import torch -from pydantic import validate_arguments -from sklearn.preprocessing import OneHotEncoder from torch import nn +from pydantic import validate_arguments # synthcity absolute from synthcity.utils.constants import DEVICE -from synthcity.utils.samplers import BaseSampler, ConditionalDatasetSampler +from synthcity.metrics.weighted_metrics import WeightedMetrics -# synthcity relative -from ..tabular_encoder import TabularEncoder +from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion # noqa +from .modules import MLPDiffusion, ResNetDiffusion # noqa +from .utils import TensorDataLoader + + +class TabDDPM(nn.Module): + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def __init__( + self, + n_iter = 10000, + lr = 0.002, + weight_decay = 1e-4, + batch_size = 1024, + num_timesteps = 1000, + gaussian_loss_type = 'mse', + scheduler = 'cosine', + device: Any = DEVICE, + log_interval: int = 100, + print_interval: int = 500, + # 
model params + model_type = 'mlp', + rtdl_params: Optional[dict] = None, # {'d_layers', 'dropout'} + dim_label_emb: int = 128, + # early stopping + n_iter_min: int = 100, + n_iter_print: int = 50, + patience: int = 5, + ) -> None: + super().__init__() + self.__dict__.update(locals()) + del self.self, self.kwargs + + def _anneal_lr(self, step): + frac_done = step / self.steps + lr = self.lr * (1 - frac_done) + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + def _update_ema(self, target_params, source_params, rate=0.999): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) -# class TabDDPM(nn.Module): -# def __init__( -# self, -# X: pd.DataFrame, + def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): + if cond is not None: + n_classes = len(np.unique(cond)) + else: + n_classes = 0 + + model_params = dict( + num_classes=n_classes, + is_y_cond=cond is not None, + rtdl_params=self.rtdl_params, + dim_t = self.dim_label_emb + ) -# def generate(self, n_samples: int) -> pd.DataFrame: -# self.eval() -# with torch.no_grad(): -# samples = self.diffusion.sample(n_samples) -# return samples + tensors = [X] if cond is None else [X, cond] + tensors = [torch.tensor(t.values, dtype=torch.float32, device=self.device) for t in tensors] + self.dataloader = TensorDataLoader(tensors, batch_size=self.batch_size) + + self.diffusion = GaussianMultinomialDiffusion( + model_type=self.model_type, + model_params=model_params, + num_numerical_features=self.encoder.n_features(), + gaussian_loss_type=self.gaussian_loss_type, + num_timesteps=self.num_timesteps, + scheduler=self.scheduler, + device=self.device + 
).to(self.device) + + self.ema_model = deepcopy(self.diffusion.denoise_fn) + for param in self.ema_model.parameters(): + param.detach_() + + self.optimizer = torch.optim.AdamW( + self.diffusion.parameters(), lr=self.lr, weight_decay=self.weight_decay) -# def forward(self, count: int) -> pd.DataFrame: -# pass \ No newline at end of file + self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) + + for step, (x, y) in enumerate(self.dataloader): + curr_loss_multi = 0.0 + curr_loss_gauss = 0.0 + curr_count = 0 + self.diffusion.train() + + self.optimizer.zero_grad() + loss_multi, loss_gauss = self.diffusion.mixed_loss(x, dict(y=y)) + loss = loss_multi + loss_gauss + loss.backward() + self.optimizer.step() + + self._anneal_lr(step) + + curr_count += len(x) + curr_loss_multi += loss_multi.item() * len(x) + curr_loss_gauss += loss_gauss.item() * len(x) + + if (step + 1) % self.log_interval == 0: + mloss = np.around(curr_loss_multi / curr_count, 4) + gloss = np.around(curr_loss_gauss / curr_count, 4) + if (step + 1) % self.print_interval == 0: + print(f'Step {(step + 1)}/{self.n_iter} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') + self.loss_history.loc[len(self.loss_history)] = [ + step + 1, mloss, gloss, mloss + gloss] + curr_count = 0 + curr_loss_gauss = 0.0 + curr_loss_multi = 0.0 + + self._update_ema(self.ema_model.parameters(), self.model.parameters()) + + if step == self.n_iter - 1: + break + + return self + + def generate(self, count: int, cond=None): + self.diffusion.eval() + sample, out_dict = self.diffusion.sample_all(count) + return sample, out_dict['y'] diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 7e93c070..70dd0a9f 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ 
-10,6 +10,7 @@ import numpy as np from .utils import * +from .modules import MLPDiffusion, ResNetDiffusion """ Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 @@ -63,9 +64,10 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( self, - denoise_fn, num_numerical_features, - num_classes=None, + num_classes, + model_type='mlp', + model_params=None, num_timesteps=1000, gaussian_loss_type='mse', gaussian_parametrization='eps', @@ -84,18 +86,41 @@ def __init__( ' This is expensive both in terms of memory and computation.') self.num_numerical_features = num_numerical_features - self.num_classes = num_classes or [0] + self.num_classes = num_classes self.num_classes_expanded = torch.from_numpy( np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))]) ).to(device) - + self.dim_input = num_numerical_features + sum(self.num_classes) + self.slices_for_classes = [np.arange(self.num_classes[0])] offsets = np.cumsum(self.num_classes) for i in range(1, len(offsets)): self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i])) self.offsets = torch.from_numpy(np.append([0], offsets)).to(device) - self._denoise_fn = denoise_fn + if model_params is None: + model_params = dict( + d_in = self.dim_input, + num_classes = 0, + is_y_cond = False, + rtdl_params = None + ) + else: + model_params['d_in'] = self.dim_input + + if model_params['rtdl_params'] is None: + model_params['rtdl_params'] = dict( + d_layers = [256, 256, 256], + dropout = 0.0 + ) + + if model_type == 'mlp': + self.denoise_fn = MLPDiffusion(**model_params) + elif model_type == 'resnet': + self.denoise_fn = ResNetDiffusion(**model_params) + else: + raise "Unknown diffusion model type!" 
+ self.gaussian_loss_type = gaussian_loss_type self.gaussian_parametrization = gaussian_parametrization self.multinomial_loss_type = multinomial_loss_type @@ -607,7 +632,7 @@ def mixed_loss(self, x, out_dict): x_in = torch.cat([x_num_t, log_x_cat_t], dim=1) - model_out = self._denoise_fn( + model_out = self.denoise_fn( x_in, t, **out_dict @@ -619,7 +644,8 @@ def mixed_loss(self, x, out_dict): loss_multi = torch.zeros((1,)).float() loss_gauss = torch.zeros((1,)).float() if x_cat.shape[1] > 0: - loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, t, pt, out_dict) / len(self.num_classes) + loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, t, pt, + out_dict) / len(self.num_classes) if x_num.shape[1] > 0: loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise) @@ -657,7 +683,7 @@ def mixed_elbo(self, x0, out_dict): else: log_x_cat_t = x_cat - model_out = self._denoise_fn( + model_out = self.denoise_fn( torch.cat([x_num_t, log_x_cat_t], dim=1), t_array, **out_dict @@ -777,7 +803,7 @@ def gaussian_ddim_sample( for t in reversed(range(T)): print(f'Sample timestep {t:4d}', end='\r') t_array = (torch.ones(b, device=device) * t).long() - out_num = self._denoise_fn(x, t_array, **out_dict) + out_num = self.denoise_fn(x, t_array, **out_dict) x = self.gaussian_ddim_step( out_num, x, @@ -831,7 +857,7 @@ def gaussian_ddim_reverse_sample( for t in range(T): print(f'Reverse timestep {t:4d}', end='\r') t_array = (torch.ones(b, device=device) * t).long() - out_num = self._denoise_fn(x, t_array, **out_dict) + out_num = self.denoise_fn(x, t_array, **out_dict) x = self.gaussian_ddim_reverse_step( out_num, x, @@ -881,7 +907,7 @@ def multinomial_ddim_step( return out @torch.no_grad() - def sample_ddim(self, num_samples, y_dist): + def sample_ddim(self, num_samples, y_dist=None): b = num_samples device = self.log_alpha.device z_norm = torch.randn((b, self.num_numerical_features), device=device) @@ -901,7 +927,7 @@ def 
sample_ddim(self, num_samples, y_dist): for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((b,), i, device=device, dtype=torch.long) - model_out = self._denoise_fn( + model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), t, **out_dict @@ -921,7 +947,8 @@ def sample_ddim(self, num_samples, y_dist): return sample, out_dict @torch.no_grad() - def sample(self, num_samples, y_dist): + def sample(self, num_samples, y_dist=None): + # TODO: handle y_dist=None b = num_samples device = self.log_alpha.device z_norm = torch.randn((b, self.num_numerical_features), device=device) @@ -941,7 +968,7 @@ def sample(self, num_samples, y_dist): for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((b,), i, device=device, dtype=torch.long) - model_out = self._denoise_fn( + model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), t, **out_dict @@ -960,7 +987,7 @@ def sample(self, num_samples, y_dist): sample = torch.cat([z_norm, z_cat], dim=1).cpu() return sample, out_dict - def sample_all(self, num_samples, y_dist, max_batch_size=2000, ddim=False): + def sample_all(self, num_samples, y_dist=None, max_batch_size=2000, ddim=False): if ddim: print('Sample using DDIM.') sample_fn = self.sample_ddim diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 472ba5b5..a6164119 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -420,6 +420,7 @@ def forward(self, x: Tensor) -> Tensor: x = self.blocks(x) x = self.head(x) return x + #### For diffusion class MLPDiffusion(nn.Module): @@ -460,7 +461,7 @@ def forward(self, x, timesteps, y=None): return self.mlp(x) class ResNetDiffusion(nn.Module): - def __init__(self, d_in, num_classes, rtdl_params, dim_t = 256): + def __init__(self, d_in, 
num_classes, is_y_cond, rtdl_params, dim_t = 256): super().__init__() self.dim_t = dim_t self.num_classes = num_classes @@ -469,10 +470,13 @@ def __init__(self, d_in, num_classes, rtdl_params, dim_t = 256): rtdl_params['d_out'] = d_in rtdl_params['emb_d'] = dim_t self.resnet = ResNet.make_baseline(**rtdl_params) - - if self.num_classes > 0: + + if self.num_classes > 0 and is_y_cond: self.label_emb = nn.Embedding(self.num_classes, dim_t) + elif self.num_classes == 0 and is_y_cond: + self.label_emb = nn.Linear(1, dim_t) + self.proj = nn.Linear(d_in, dim_t) self.time_embed = nn.Sequential( nn.Linear(dim_t, dim_t), nn.SiLU(), @@ -481,6 +485,11 @@ def __init__(self, d_in, num_classes, rtdl_params, dim_t = 256): def forward(self, x, timesteps, y=None): emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) - if y is not None and self.num_classes > 0: - emb += self.label_emb(y.squeeze()) - return self.resnet(x, emb) + if self.is_y_cond and y is not None: + if self.num_classes > 0: + y = y.squeeze() + else: + y = y.resize(y.size(0), 1).float() + emb += F.silu(self.label_emb(y)) + x = self.proj(x) + emb + return self.resnet(x) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 95abd42a..a0021a68 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -168,7 +168,45 @@ def sliced_logsumexp(x, slices): def log_onehot_to_index(log_x): return log_x.argmax(1) + class FoundNANsError(BaseException): """Found NANs during sampling""" def __init__(self, message='Found NANs during sampling.'): super(FoundNANsError, self).__init__(message) + + +class TensorDataLoader: + """ + A DataLoader-like object for a set of tensors that can be much faster than + TensorDataset + DataLoader because dataloader grabs individual indices of + the dataset and calls cat (slow). 
+ Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 + """ + def __init__(self, *tensors, batch_size=32, shuffle=False): + """ + Initialize a FastTensorDataLoader. + :param *tensors: tensors to store. Must have the same length @ dim 0. + :param batch_size: batch size to load. + :param shuffle: if True, shuffle the data *in-place* whenever an + iterator is created out of this object. + :returns: A FastTensorDataLoader. + """ + assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + self.tensors = tensors + self.dataset_len = self.tensors[0].shape[0] + self.batch_size = batch_size + self.shuffle = shuffle + + def __iter__(self): + i = 0 + idx = np.arange(self.dataset_len) + if self.shuffle: + np.random.shuffle(idx) + while True: + j = i + self.batch_size + s = slice(i, j) + if j > self.dataset_len: + s = list(range(i, self.dataset_len)) + list(range(0, j - self.dataset_len)) + if self.shuffle: + np.random.shuffle(idx) + yield tuple(t[idx[s]] for t in self.tensors) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 5565eb92..6ed9fd89 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -25,7 +25,7 @@ FloatDistribution, IntegerDistribution, ) -from synthcity.plugins.core.models.tabular_ddpm import GaussianMultinomialDiffusion, MLPDiffusion, ResNetDiffusion +from synthcity.plugins.core.models.tabular_ddpm import TabDDPM from synthcity.plugins.core.models.tabular_encoder import TabularEncoder from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema @@ -57,24 +57,27 @@ class TabDDPMPlugin(Plugin): @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, + *, + is_classification: bool = False, n_iter = 1000, lr = 0.002, weight_decay = 1e-4, batch_size = 1024, model_type = 'mlp', - model_params = None, num_timesteps = 1000, gaussian_loss_type = 
'mse', scheduler = 'cosine', - change_val = False, device: Any = DEVICE, log_interval: int = 100, print_interval: int = 500, + # model params + rtdl_params: Optional[dict] = None, # {'d_layers', 'dropout'} + dim_label_emb: int = 128, # early stopping n_iter_min: int = 100, n_iter_print: int = 50, patience: int = 5, - patience_metric: Optional[WeightedMetrics] = None, + # patience_metric: Optional[WeightedMetrics] = None, # core plugin arguments random_state: int = 0, workspace: Path = Path("workspace"), @@ -90,16 +93,27 @@ def __init__( compress_dataset=compress_dataset, **kwargs ) - - if patience_metric is None: - patience_metric = WeightedMetrics( - metrics=[("detection", "detection_mlp")], - weights=[1], - workspace=workspace, - ) - - self.__dict__.update(locals()) - del self.self, self.kwargs + + self.is_classification = is_classification + + self.model = TabDDPM( + n_iter=n_iter, + lr=lr, + weight_decay=weight_decay, + batch_size=batch_size, + num_timesteps=num_timesteps, + gaussian_loss_type=gaussian_loss_type, + scheduler=scheduler, + device=device, + log_interval=log_interval, + print_interval=print_interval, + model_type=model_type, + rtdl_params=rtdl_params, + dim_label_emb=dim_label_emb, + n_iter_min=n_iter_min, + n_iter_print=n_iter_print, + patience=patience, + ) @staticmethod def name() -> str: @@ -113,98 +127,23 @@ def type() -> str: def hyperparameter_space(**kwargs: Any) -> List[Distribution]: raise NotImplementedError - def _anneal_lr(self, step): - frac_done = step / self.steps - lr = self.lr * (1 - frac_done) - for param_group in self.optimizer.param_groups: - param_group["lr"] = lr - - def _one_step(self, x, out_dict): - x = x.to(self.device) - for k in out_dict: - out_dict[k] = out_dict[k].long().to(self.device) - self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) - loss = loss_multi + loss_gauss - loss.backward() - self.optimizer.step() - return loss_multi, loss_gauss - - def _update_ema(self, 
target_params, source_params, rate=0.999): - """ - Update target parameters to be closer to those of source parameters using - an exponential moving average. - :param target_params: the target parameter sequence. - :param source_params: the source parameter sequence. - :param rate: the EMA rate (closer to 1 means slower). - """ - for targ, src in zip(target_params, source_params): - targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) - - def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": - # TODO: add parameters of TabularEncoder - self.encoder = TabularEncoder().fit(X) - - if self.model_type == 'mlp': - self.model = MLPDiffusion(**self.model_params) - elif self.model_type == 'resnet': - self.model = ResNetDiffusion(**self.model_params) - else: - raise "Unknown model!" - - self.diffusion = GaussianMultinomialDiffusion( - denoise_fn=self.model, - num_numerical_features=self.encoder.n_features(), - gaussian_loss_type=self.gaussian_loss_type, - num_timesteps=self.num_timesteps, - scheduler=self.scheduler, - device=self.device - ).to(self.device) + def _fit(self, data: DataLoader, cond: pd.Series = None, **kwargs) -> "TabDDPMPlugin": + if self.is_classification: + assert cond is None + _, cond = data.unpack() + + if cond is not None: + cond = pd.Series(cond, index=data.index) + data = data.dataframe() + + # self.encoder = TabularEncoder().fit(X) - self.ema_model = deepcopy(self.model) - for param in self.ema_model.parameters(): - param.detach_() - - self.optimizer = torch.optim.AdamW( - self.diffusion.parameters(), lr=self.lr, weight_decay=self.weight_decay) - - # TODO: check data type of X - self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) + self.model.fit(data, cond, **kwargs) - curr_loss_multi = 0.0 - curr_loss_gauss = 0.0 - - curr_count = 0 - for step in range(self.n_iter): - x, out_dict = next(X) - out_dict = {'y': out_dict} - batch_loss_multi, batch_loss_gauss = self._one_step(x, out_dict) - - 
self._anneal_lr(step) - - curr_count += len(x) - curr_loss_multi += batch_loss_multi.item() * len(x) - curr_loss_gauss += batch_loss_gauss.item() * len(x) - - if (step + 1) % self.log_interval == 0: - mloss = np.around(curr_loss_multi / curr_count, 4) - gloss = np.around(curr_loss_gauss / curr_count, 4) - if (step + 1) % self.print_interval == 0: - print(f'Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') - self.loss_history.loc[len(self.loss_history)] = [ - step + 1, mloss, gloss, mloss + gloss] - curr_count = 0 - curr_loss_gauss = 0.0 - curr_loss_multi = 0.0 - - self._update_ema(self.ema_model.parameters(), self.model.parameters()) - - return self - - def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader: - self.diffusion.eval() - # TODO: check self.model.sample_all - return self._safe_generate(self.diffusion.sample_all, count, syn_schema) - + def _generate(self, count: int, syn_schema: Schema, cond=None, **kwargs: Any) -> DataLoader: + def callback(count, cond): + sample, cond = self.model.generate(count, cond=cond) + return sample + return self._safe_generate(callback, count, syn_schema, cond=cond, **kwargs) plugin = TabDDPMPlugin From 0abdc010a15ee3abe4330ffcaf2c310baf7046d2 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 17:35:29 +0100 Subject: [PATCH 04/95] handle discrete cols and label generation --- .../core/models/tabular_ddpm/__init__.py | 24 ++- .../gaussian_multinomial_diffsuion.py | 167 ++++++++---------- .../core/models/tabular_ddpm/modules.py | 1 + .../plugins/core/models/tabular_encoder.py | 15 +- src/synthcity/plugins/generic/plugin_ddpm.py | 18 +- src/synthcity/utils/dataframe.py | 20 ++- 6 files changed, 126 insertions(+), 119 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 46977c79..d4fa28e6 100644 --- 
a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -11,6 +11,7 @@ # synthcity absolute from synthcity.utils.constants import DEVICE +from synthcity.utils.dataframe import discrete_columns from synthcity.metrics.weighted_metrics import WeightedMetrics from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion # noqa @@ -65,12 +66,21 @@ def _update_ema(self, target_params, source_params, rate=0.999): def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): if cond is not None: - n_classes = len(np.unique(cond)) + n_labels = cond.nunique() else: - n_classes = 0 + n_labels = 0 + cat_cols = discrete_columns(X, return_counts=True) + ini_cols = X.columns + cat_cols, cat_counts = zip(*cat_cols) + # reorder the columns so that the categorical ones go to the end + X = X[np.hstack([X.columns[~X.keys().isin(cat_cols)], cat_cols])] + cur_cols = X.columns + # find the permutation from the reordered columns to the original ones + self._col_perm = np.argsort(cur_cols)[np.argsort(np.argsort(ini_cols))] + model_params = dict( - num_classes=n_classes, + num_classes=n_labels, is_y_cond=cond is not None, rtdl_params=self.rtdl_params, dim_t = self.dim_label_emb @@ -83,7 +93,8 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): self.diffusion = GaussianMultinomialDiffusion( model_type=self.model_type, model_params=model_params, - num_numerical_features=self.encoder.n_features(), + num_categorical_features=cat_counts, + num_numerical_features=X.shape[1]-len(cat_cols), gaussian_loss_type=self.gaussian_loss_type, num_timesteps=self.num_timesteps, scheduler=self.scheduler, @@ -137,5 +148,6 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): def generate(self, count: int, cond=None): self.diffusion.eval() - sample, out_dict = self.diffusion.sample_all(count) - return sample, out_dict['y'] + sample = self.diffusion.sample_all(count, cond).detach().cpu().numpy() + sample = sample[:, 
self._col_perm] + return sample diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 70dd0a9f..dfcbd00a 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -65,7 +65,7 @@ class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( self, num_numerical_features, - num_classes, + num_categorical_features, model_type='mlp', model_params=None, num_timesteps=1000, @@ -85,12 +85,12 @@ def __init__( print('Computing the loss using the bound on _all_ timesteps.' ' This is expensive both in terms of memory and computation.') - self.num_numerical_features = num_numerical_features - self.num_classes = num_classes + self.num_numerics = num_numerical_features + self.num_classes = num_categorical_features self.num_classes_expanded = torch.from_numpy( - np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))]) + np.concatenate([num_categorical_features[i].repeat(num_categorical_features[i]) for i in range(len(num_categorical_features))]) ).to(device) - self.dim_input = num_numerical_features + sum(self.num_classes) + self.dim_input = self.num_numerics + sum(self.num_classes) self.slices_for_classes = [np.arange(self.num_classes[0])] offsets = np.cumsum(self.num_classes) @@ -391,7 +391,7 @@ def q_pred(self, log_x_start, t): return log_probs - def predict_start(self, model_out, log_x_t, t, out_dict): + def predict_start(self, model_out, log_x_t): # model_out = self._denoise_fn(x_t, t.to(x_t.device), **out_dict) @@ -434,25 +434,25 @@ def q_posterior(self, log_x_start, log_x_t, t): return log_EV_xtmin_given_xt_given_xstart - def p_pred(self, model_out, log_x, t, out_dict): + def p_pred(self, model_out, log_x, t): if self.parametrization == 'x0': - log_x_recon = self.predict_start(model_out, 
log_x, t=t, out_dict=out_dict) + log_x_recon = self.predict_start(model_out, log_x) log_model_pred = self.q_posterior( log_x_start=log_x_recon, log_x_t=log_x, t=t) elif self.parametrization == 'direct': - log_model_pred = self.predict_start(model_out, log_x, t=t, out_dict=out_dict) + log_model_pred = self.predict_start(model_out, log_x) else: raise ValueError return log_model_pred @torch.no_grad() - def p_sample(self, model_out, log_x, t, out_dict): - model_log_prob = self.p_pred(model_out, log_x=log_x, t=t, out_dict=out_dict) + def p_sample(self, model_out, log_x, t): + model_log_prob = self.p_pred(model_out, log_x=log_x, t=t) out = self.log_sample_categorical(model_log_prob) return out @torch.no_grad() - def p_sample_loop(self, shape, out_dict): + def p_sample_loop(self, shape): device = self.log_alpha.device b = shape[0] @@ -460,12 +460,12 @@ def p_sample_loop(self, shape, out_dict): img = torch.randn(shape, device=device) for i in reversed(range(1, self.num_timesteps)): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), out_dict) + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) return img @torch.no_grad() - def _sample(self, image_size, out_dict, batch_size = 16): - return self.p_sample_loop((batch_size, 3, image_size, image_size), out_dict) + def _sample(self, image_size, batch_size = 16): + return self.p_sample_loop((batch_size, 3, image_size, image_size)) @torch.no_grad() def interpolate(self, x1, x2, t = None, lam = 0.5): @@ -502,7 +502,7 @@ def q_sample(self, log_x_start, t): return log_sample - def nll(self, log_x_start, out_dict): + def nll(self, log_x_start): b = log_x_start.size(0) device = log_x_start.device loss = 0 @@ -512,8 +512,7 @@ def nll(self, log_x_start, out_dict): kl = self.compute_Lt( log_x_start=log_x_start, log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array), - t=t_array, - out_dict=out_dict) + t=t_array) loss += kl @@ -532,10 +531,10 @@ def kl_prior(self, log_x_start): 
kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) return sum_except_batch(kl_prior) - def compute_Lt(self, model_out, log_x_start, log_x_t, t, out_dict, detach_mean=False): + def compute_Lt(self, model_out, log_x_start, log_x_t, t, detach_mean=False): log_true_prob = self.q_posterior( log_x_start=log_x_start, log_x_t=log_x_t, t=t) - log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t, out_dict=out_dict) + log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t) if detach_mean: log_model_prob = log_model_prob.detach() @@ -574,11 +573,11 @@ def sample_time(self, b, device, method='uniform'): else: raise ValueError - def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt, out_dict): + def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt): if self.multinomial_loss_type == 'vb_stochastic': kl = self.compute_Lt( - model_out, log_x_start, log_x_t, t, out_dict + model_out, log_x_start, log_x_t, t ) kl_prior = self.kl_prior(log_x_start) # Upweigh loss term of the kl @@ -593,10 +592,13 @@ def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt, out_dict): else: raise ValueError() - def log_prob(self, x, out_dict): + #! Not used + def log_prob(self, x): b, device = x.size(0), x.device + if self.training: - return self._multinomial_loss(x, out_dict) + #! 
not enough arguments + return self._multinomial_loss(x) else: log_x_start = index_to_log_onehot(x, self.num_classes) @@ -604,7 +606,7 @@ def log_prob(self, x, out_dict): t, pt = self.sample_time(b, device, 'importance') kl = self.compute_Lt( - log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t, out_dict) + log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t) kl_prior = self.kl_prior(log_x_start) @@ -613,13 +615,13 @@ def log_prob(self, x, out_dict): return -loss - def mixed_loss(self, x, out_dict): + def mixed_loss(self, x, cond=None): b = x.shape[0] device = x.device t, pt = self.sample_time(b, device, 'uniform') - x_num = x[:, :self.num_numerical_features] - x_cat = x[:, self.num_numerical_features:] + x_num = x[:, :self.num_numerics] + x_cat = x[:, self.num_numerics:] x_num_t = x_num log_x_cat_t = x_cat @@ -634,18 +636,17 @@ def mixed_loss(self, x, out_dict): model_out = self.denoise_fn( x_in, - t, - **out_dict + t, y=cond ) - model_out_num = model_out[:, :self.num_numerical_features] - model_out_cat = model_out[:, self.num_numerical_features:] + model_out_num = model_out[:, :self.num_numerics] + model_out_cat = model_out[:, self.num_numerics:] loss_multi = torch.zeros((1,)).float() loss_gauss = torch.zeros((1,)).float() if x_cat.shape[1] > 0: - loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, t, pt, - out_dict) / len(self.num_classes) + loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, + t, pt) / len(self.num_classes) if x_num.shape[1] > 0: loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise) @@ -656,12 +657,12 @@ def mixed_loss(self, x, out_dict): return loss_multi.mean(), loss_gauss.mean() @torch.no_grad() - def mixed_elbo(self, x0, out_dict): + def mixed_elbo(self, x0, cond=None): b = x0.size(0) device = x0.device - x_num = x0[:, :self.num_numerical_features] - x_cat = x0[:, self.num_numerical_features:] + x_num = x0[:, :self.num_numerics] + x_cat = x0[:, 
self.num_numerics:] has_cat = x_cat.shape[1] > 0 if has_cat: log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes).to(device) @@ -685,12 +686,11 @@ def mixed_elbo(self, x0, out_dict): model_out = self.denoise_fn( torch.cat([x_num_t, log_x_cat_t], dim=1), - t_array, - **out_dict + t_array, y=cond ) - model_out_num = model_out[:, :self.num_numerical_features] - model_out_cat = model_out[:, self.num_numerical_features:] + model_out_num = model_out[:, :self.num_numerics] + model_out_cat = model_out[:, self.num_numerics:] kl = torch.tensor([0.0]) if has_cat: @@ -699,7 +699,6 @@ def mixed_elbo(self, x0, out_dict): log_x_start=log_x_cat, log_x_t=log_x_cat_t, t=t_array, - out_dict=out_dict ) out = self._vb_terms_bpd( @@ -794,7 +793,7 @@ def gaussian_ddim_sample( self, noise, T, - out_dict, + cond=None, eta=0.0 ): x = noise @@ -803,7 +802,7 @@ def gaussian_ddim_sample( for t in reversed(range(T)): print(f'Sample timestep {t:4d}', end='\r') t_array = (torch.ones(b, device=device) * t).long() - out_num = self.denoise_fn(x, t_array, **out_dict) + out_num = self.denoise_fn(x, t_array, y=cond) x = self.gaussian_ddim_step( out_num, x, @@ -850,14 +849,14 @@ def gaussian_ddim_reverse_sample( self, x, T, - out_dict, + cond=None ): b = x.shape[0] device = x.device for t in range(T): print(f'Reverse timestep {t:4d}', end='\r') t_array = (torch.ones(b, device=device) * t).long() - out_num = self.denoise_fn(x, t_array, **out_dict) + out_num = self.denoise_fn(x, t_array, y=cond) x = self.gaussian_ddim_reverse_step( out_num, x, @@ -875,11 +874,10 @@ def multinomial_ddim_step( model_out_cat, log_x_t, t, - out_dict, eta=0.0 ): # not ddim, essentially - log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t, t=t, out_dict=out_dict) + log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t) alpha_bar = extract(self.alphas_cumprod, t, log_x_t.shape) alpha_bar_prev = extract(self.alphas_cumprod_prev, t, log_x_t.shape) @@ -907,10 +905,10 @@ def multinomial_ddim_step( return out 
@torch.no_grad() - def sample_ddim(self, num_samples, y_dist=None): + def sample_ddim(self, num_samples, cond=None): b = num_samples device = self.log_alpha.device - z_norm = torch.randn((b, self.num_numerical_features), device=device) + z_norm = torch.randn((b, self.num_numerics), device=device) has_cat = self.num_classes[0] != 0 log_z = torch.zeros((b, 0), device=device).float() @@ -918,25 +916,24 @@ def sample_ddim(self, num_samples, y_dist=None): uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device) log_z = self.log_sample_categorical(uniform_logits) - y = torch.multinomial( - y_dist, - num_samples=b, - replacement=True - ) - out_dict = {'y': y.long().to(device)} + # y = torch.multinomial( + # cond, + # num_samples=b, + # replacement=True + # ) + # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), - t, - **out_dict + t, y=cond ) - model_out_num = model_out[:, :self.num_numerical_features] - model_out_cat = model_out[:, self.num_numerical_features:] + model_out_num = model_out[:, :self.num_numerics] + model_out_cat = model_out[:, self.num_numerics:] z_norm = self.gaussian_ddim_step(model_out_num, z_norm, t, clip_denoised=False) if has_cat: - log_z = self.multinomial_ddim_step(model_out_cat, log_z, t, out_dict) + log_z = self.multinomial_ddim_step(model_out_cat, log_z, t) print() z_ohe = torch.exp(log_z).round() @@ -944,14 +941,13 @@ def sample_ddim(self, num_samples, y_dist=None): if has_cat: z_cat = ohe_to_categories(z_ohe, self.num_classes) sample = torch.cat([z_norm, z_cat], dim=1).cpu() - return sample, out_dict + return sample @torch.no_grad() - def sample(self, num_samples, y_dist=None): - # TODO: handle y_dist=None + def sample(self, num_samples, cond=None): b = num_samples device = self.log_alpha.device - z_norm = 
torch.randn((b, self.num_numerical_features), device=device) + z_norm = torch.randn((b, self.num_numerics), device=device) has_cat = self.num_classes[0] != 0 log_z = torch.zeros((b, 0), device=device).float() @@ -959,25 +955,24 @@ def sample(self, num_samples, y_dist=None): uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device) log_z = self.log_sample_categorical(uniform_logits) - y = torch.multinomial( - y_dist, - num_samples=b, - replacement=True - ) - out_dict = {'y': y.long().to(device)} + # y = torch.multinomial( + # cond, + # num_samples=b, + # replacement=True + # ) + # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), - t, - **out_dict + t, y=cond ) - model_out_num = model_out[:, :self.num_numerical_features] - model_out_cat = model_out[:, self.num_numerical_features:] + model_out_num = model_out[:, :self.num_numerics] + model_out_cat = model_out[:, self.num_numerics:] z_norm = self.gaussian_p_sample(model_out_num, z_norm, t, clip_denoised=False)['sample'] if has_cat: - log_z = self.p_sample(model_out_cat, log_z, t, out_dict) + log_z = self.p_sample(model_out_cat, log_z, t=t) print() z_ohe = torch.exp(log_z).round() @@ -985,30 +980,22 @@ def sample(self, num_samples, y_dist=None): if has_cat: z_cat = ohe_to_categories(z_ohe, self.num_classes) sample = torch.cat([z_norm, z_cat], dim=1).cpu() - return sample, out_dict + return sample - def sample_all(self, num_samples, y_dist=None, max_batch_size=2000, ddim=False): + def sample_all(self, num_samples, cond=None, max_batch_size=2000, ddim=False): if ddim: print('Sample using DDIM.') sample_fn = self.sample_ddim else: sample_fn = self.sample - bs = np.diff(list(range(0, num_samples, max_batch_size)) + [num_samples]) - all_y = [] + bs = np.diff([*range(0, num_samples, 
max_batch_size), num_samples]) all_samples = [] for b in bs: - sample, out_dict = sample_fn(b, y_dist) - mask_nan = torch.any(sample.isnan(), dim=1) - sample = sample[~mask_nan] - if sample.shape[0] != b: + sample = sample_fn(b, cond) + if torch.any(sample.isnan()).item(): raise FoundNANsError - out_dict['y'] = out_dict['y'][~mask_nan] all_samples.append(sample) - all_y.append(out_dict['y'].cpu()) - - x_gen = torch.cat(all_samples, dim=0)[:num_samples] - y_gen = torch.cat(all_y, dim=0)[:num_samples] - return x_gen, y_gen + return torch.cat(all_samples, dim=0) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index a6164119..44c63884 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -460,6 +460,7 @@ def forward(self, x, timesteps, y=None): x = self.proj(x) + emb return self.mlp(x) + class ResNetDiffusion(nn.Module): def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 256): super().__init__() diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 33929dc7..638b5e6c 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -14,6 +14,7 @@ # synthcity absolute import synthcity.logger as log from synthcity.utils.serialization import dataframe_hash +from synthcity.utils.dataframe import discrete_columns as find_cat_cols # synthcity relative from .data_encoder import ContinuousDataEncoder @@ -103,18 +104,14 @@ def _fit_continuous(self, data: pd.Series) -> FeatureInfo: ) def fit( - self, raw_data: pd.Series, discrete_columns: Optional[List] = None + self, raw_data: pd.DataFrame, discrete_columns: Optional[List] = None ) -> "BinEncoder": """Fit the ``BinEncoder``. 
Fits a ``ContinuousDataEncoder`` for continuous columns """ if discrete_columns is None: - discrete_columns = [] - - for col in raw_data.columns: - if len(raw_data[col].unique()) < self.categorical_limit: - discrete_columns.append(col) + discrete_columns = find_cat_cols(raw_data, self.categorical_limit) self.output_dimensions = 0 @@ -247,11 +244,7 @@ def fit( This step also counts the #columns in matrix data and span information. """ if discrete_columns is None: - discrete_columns = [] - - for col in raw_data.columns: - if len(raw_data[col].unique()) < self.categorical_limit: - discrete_columns.append(col) + discrete_columns = find_cat_cols(raw_data, self.categorical_limit) self.output_dimensions = 0 diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 6ed9fd89..11371c11 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -131,19 +131,25 @@ def _fit(self, data: DataLoader, cond: pd.Series = None, **kwargs) -> "TabDDPMPl if self.is_classification: assert cond is None _, cond = data.unpack() - + self._labels, self._cond_dist = np.unique(cond, return_counts=True) + self._cond_dist /= self._cond_dist.sum() + if cond is not None: cond = pd.Series(cond, index=data.index) - data = data.dataframe() + # NOTE: should we include the target column in `data`? 
+ data = data.dataframe() + # self.encoder = TabularEncoder().fit(X) self.model.fit(data, cond, **kwargs) def _generate(self, count: int, syn_schema: Schema, cond=None, **kwargs: Any) -> DataLoader: - def callback(count, cond): - sample, cond = self.model.generate(count, cond=cond) - return sample - return self._safe_generate(callback, count, syn_schema, cond=cond, **kwargs) + if self.is_classification and cond is None: + # randomly generate labels following the distribution of the training data + cond = np.random.choice(self._labels, size=count, p=self._cond_dist) + def callback(count, cond=cond): + return self.model.generate(count, cond=cond) + return self._safe_generate(callback, count, syn_schema, **kwargs) plugin = TabDDPMPlugin diff --git a/src/synthcity/utils/dataframe.py b/src/synthcity/utils/dataframe.py index 35d7226b..069b6eab 100644 --- a/src/synthcity/utils/dataframe.py +++ b/src/synthcity/utils/dataframe.py @@ -4,10 +4,18 @@ def constant_columns(dataframe: pd.DataFrame) -> list: """ - Drops constant value columns of pandas dataframe. + Find constant value columns in a pandas dataframe. """ - result = [] - for column in dataframe.columns: - if len(dataframe[column].unique()) == 1: - result.append(column) - return result + return discrete_columns(dataframe, 2) + + +def discrete_columns(dataframe: pd.DataFrame, + max_classes: int = 10, + return_counts=False) -> list: + """ + Find columns containing discrete values in a pandas dataframe. 
+ """ + return [(col, cnt) if return_counts else col + for col, vals in dataframe.items() + for cnt in [vals.nunique()] + if cnt < max_classes] From 405a052f16e9cc5bed6774ab95aa33f248331ce1 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 18:01:38 +0100 Subject: [PATCH 05/95] add hparam space and update tests of DDPM --- src/synthcity/plugins/generic/plugin_ddpm.py | 33 +++++++++++++++- tests/plugins/generic/test_ddpm.py | 40 ++++++++++++-------- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 11371c11..f2f07f80 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -71,7 +71,9 @@ def __init__( log_interval: int = 100, print_interval: int = 500, # model params - rtdl_params: Optional[dict] = None, # {'d_layers', 'dropout'} + num_layers: int = 3, + dim_hidden: int = 256, + dropout: float = 0.0, dim_label_emb: int = 128, # early stopping n_iter_min: int = 100, @@ -96,6 +98,10 @@ def __init__( self.is_classification = is_classification + rtdl_params = dict( + d_layers = [self.dim_hidden] * self.num_layers, + dropout = self.dropout + ) self.model = TabDDPM( n_iter=n_iter, lr=lr, @@ -125,7 +131,30 @@ def type() -> str: @staticmethod def hyperparameter_space(**kwargs: Any) -> List[Distribution]: - raise NotImplementedError + """ + Hyperparameter Search space (from the paper) + ---------------------------------------------- + Learning rate LogUniform[0.00001, 0.003] + Batch size Cat{256, 4096} + Diffusion timesteps Cat{100, 1000} + Training iterations Cat{5000, 10000, 20000} + Number of MLP layers Int{2, 4, 6, 8} + MLP width of layers Int{128, 256, 512, 1024} + Proportion of samples Float{0.25, 0.5, 1, 2, 4, 8} + ---------------------------------------------- + Dropout 0.0 + Scheduler cosine (Nichol, 2021) + Gaussian diffusion loss MSE + """ + return [ + # TODO: change to 
loguniform distribution + CategoricalDistribution(name="lr", choices=[1e-5, 1e-4, 1e-3, 2e-3, 3e-3]), + CategoricalDistribution(name="batch_size", choices=[256, 4096]), + CategoricalDistribution(name="num_timesteps", choices=[100, 1000]), + CategoricalDistribution(name="n_iter", choices=[5000, 10000, 20000]), + CategoricalDistribution(name="num_layers", choices=[2, 4, 6, 8]), + CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]), + ] def _fit(self, data: DataLoader, cond: pd.Series = None, **kwargs) -> "TabDDPMPlugin": if self.is_classification: diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index ec112f48..a398c000 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -13,7 +13,14 @@ from synthcity.plugins.generic.plugin_ddpm import plugin plugin_name = "ddpm" -plugin_args = {"n_iter": 100} +plugin_args = dict( + n_iter=100, + is_classification=True, + # rtdl_params=dict( + # d_layers=[256, 256], + # dropout=0.0 + # ) +) @pytest.mark.parametrize( @@ -37,13 +44,6 @@ def test_plugin_type(test_plugin: Plugin) -> None: assert test_plugin.type() == "generic" -@pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) -) -def test_plugin_hyperparams(test_plugin: Plugin) -> None: - assert len(test_plugin.hyperparameter_space()) == 9 - - @pytest.mark.parametrize( "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) ) @@ -52,8 +52,10 @@ def test_plugin_fit(test_plugin: Plugin) -> None: test_plugin.fit(GenericDataLoader(X)) -def test_plugin_generate() -> None: - test_plugin = plugin(n_layers_hidden=2, n_units_hidden=100, n_iter=50) +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_generate(test_plugin: Plugin) -> None: X = pd.DataFrame(load_iris()["data"]) test_plugin.fit(GenericDataLoader(X)) @@ -66,8 +68,10 @@ def test_plugin_generate() -> None: assert 
test_plugin.schema_includes(X_gen) -def test_plugin_generate_constraints() -> None: - test_plugin = plugin(n_layers_hidden=2, n_units_hidden=100, n_iter=50) +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_generate_constraints(test_plugin: Plugin) -> None: X = pd.DataFrame(load_iris()["data"]) test_plugin.fit(GenericDataLoader(X)) @@ -96,23 +100,29 @@ def test_plugin_generate_constraints() -> None: assert list(X_gen.columns) == list(X.columns) +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_hyperparams(test_plugin: Plugin) -> None: + assert len(test_plugin.hyperparameter_space()) == 6 + + def test_sample_hyperparams() -> None: for i in range(100): args = plugin.sample_hyperparameters() - assert plugin(**args) is not None @pytest.mark.slow @pytest.mark.parametrize("compress_dataset", [True, False]) -def test_eval_performance_nflow(compress_dataset: bool) -> None: +def test_eval_performance_ddpm(compress_dataset: bool) -> None: results = [] Xraw, y = load_iris(return_X_y=True, as_frame=True) Xraw["target"] = y X = GenericDataLoader(Xraw) - for retry in range(2): + for _ in range(2): test_plugin = plugin(n_iter=5000, compress_dataset=compress_dataset) evaluator = PerformanceEvaluatorXGB() From 0e36041c056a6642810cf3c3203d7d5e832dfd7f Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 19:34:32 +0100 Subject: [PATCH 06/95] debug and test DDPM --- .../core/models/tabular_ddpm/.lib/__init__.py | 12 - .../core/models/tabular_ddpm/.lib/data.py | 718 ------------------ .../core/models/tabular_ddpm/.lib/deep.py | 168 ---- .../core/models/tabular_ddpm/.lib/env.py | 39 - .../core/models/tabular_ddpm/.lib/metrics.py | 158 ---- .../core/models/tabular_ddpm/.lib/util.py | 433 ----------- .../core/models/tabular_ddpm/.pipeline.py | 80 -- .../core/models/tabular_ddpm/.sample.py | 159 ---- 
.../core/models/tabular_ddpm/.train.py | 156 ---- .../plugins/core/models/tabular_ddpm/.tune.py | 127 ---- .../core/models/tabular_ddpm/.utils_train.py | 88 --- .../core/models/tabular_ddpm/README.md | 3 - .../core/models/tabular_ddpm/__init__.py | 20 +- .../gaussian_multinomial_diffsuion.py | 34 +- .../core/models/tabular_ddpm/requirements.txt | 15 - src/synthcity/plugins/generic/plugin_ddpm.py | 14 +- src/temp.py | 15 + tests/plugins/generic/test_ddpm.py | 1 + 18 files changed, 57 insertions(+), 2183 deletions(-) delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.lib/util.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.pipeline.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.sample.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.train.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.tune.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/.utils_train.py delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/README.md delete mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt create mode 100644 src/temp.py diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py deleted file mode 100644 index 54d6f6bb..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from icecream import install - -torch.set_num_threads(1) -install() - -from . 
import env # noqa -from .data import * # noqa -from .deep import * # noqa -from .env import * # noqa -from .metrics import * # noqa -from .util import * # noqa diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py deleted file mode 100644 index 912ce259..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/data.py +++ /dev/null @@ -1,718 +0,0 @@ -import hashlib -from collections import Counter -from copy import deepcopy -from dataclasses import astuple, dataclass, replace -from importlib.resources import path -from pathlib import Path -from typing import Any, Literal, Optional, Union, cast, Tuple, Dict, List - -import numpy as np -import pandas as pd -from sklearn.model_selection import train_test_split -from sklearn.pipeline import make_pipeline -import sklearn.preprocessing -import torch -import os -from category_encoders import LeaveOneOutEncoder -from sklearn.impute import SimpleImputer -from sklearn.preprocessing import StandardScaler -from scipy.spatial.distance import cdist - -from . 
import env, util -from .metrics import calculate_metrics as calculate_metrics_ -from .util import TaskType, load_json - -ArrayDict = Dict[str, np.ndarray] -TensorDict = Dict[str, torch.Tensor] - - -CAT_MISSING_VALUE = '__nan__' -CAT_RARE_VALUE = '__rare__' -Normalization = Literal['standard', 'quantile', 'minmax'] -NumNanPolicy = Literal['drop-rows', 'mean'] -CatNanPolicy = Literal['most_frequent'] -CatEncoding = Literal['one-hot', 'counter'] -YPolicy = Literal['default'] - - -class StandardScaler1d(StandardScaler): - def partial_fit(self, X, *args, **kwargs): - assert X.ndim == 1 - return super().partial_fit(X[:, None], *args, **kwargs) - - def transform(self, X, *args, **kwargs): - assert X.ndim == 1 - return super().transform(X[:, None], *args, **kwargs).squeeze(1) - - def inverse_transform(self, X, *args, **kwargs): - assert X.ndim == 1 - return super().inverse_transform(X[:, None], *args, **kwargs).squeeze(1) - - -def get_category_sizes(X: Union[torch.Tensor, np.ndarray]) -> List[int]: - XT = X.T.cpu().tolist() if isinstance(X, torch.Tensor) else X.T.tolist() - return [len(set(x)) for x in XT] - - -@dataclass(frozen=False) -class Dataset: - X_num: Optional[ArrayDict] - X_cat: Optional[ArrayDict] - y: ArrayDict - y_info: Dict[str, Any] - task_type: TaskType - n_classes: Optional[int] - - @classmethod - def from_dir(cls, dir_: Union[Path, str]) -> 'Dataset': - dir_ = Path(dir_) - splits = [k for k in ['train', 'val', 'test'] if dir_.joinpath(f'y_{k}.npy').exists()] - - def load(item) -> ArrayDict: - return { - x: cast(np.ndarray, np.load(dir_ / f'{item}_{x}.npy', allow_pickle=True)) # type: ignore[code] - for x in splits - } - - if Path(dir_ / 'info.json').exists(): - info = util.load_json(dir_ / 'info.json') - else: - info = None - return Dataset( - load('X_num') if dir_.joinpath('X_num_train.npy').exists() else None, - load('X_cat') if dir_.joinpath('X_cat_train.npy').exists() else None, - load('y'), - {}, - TaskType(info['task_type']), - 
info.get('n_classes'), - ) - - @property - def is_binclass(self) -> bool: - return self.task_type == TaskType.BINCLASS - - @property - def is_multiclass(self) -> bool: - return self.task_type == TaskType.MULTICLASS - - @property - def is_regression(self) -> bool: - return self.task_type == TaskType.REGRESSION - - @property - def n_num_features(self) -> int: - return 0 if self.X_num is None else self.X_num['train'].shape[1] - - @property - def n_cat_features(self) -> int: - return 0 if self.X_cat is None else self.X_cat['train'].shape[1] - - @property - def n_features(self) -> int: - return self.n_num_features + self.n_cat_features - - def size(self, part: Optional[str]) -> int: - return sum(map(len, self.y.values())) if part is None else len(self.y[part]) - - @property - def nn_output_dim(self) -> int: - if self.is_multiclass: - assert self.n_classes is not None - return self.n_classes - else: - return 1 - - def get_category_sizes(self, part: str) -> List[int]: - return [] if self.X_cat is None else get_category_sizes(self.X_cat[part]) - - def calculate_metrics( - self, - predictions: Dict[str, np.ndarray], - prediction_type: Optional[str], - ) -> Dict[str, Any]: - metrics = { - x: calculate_metrics_( - self.y[x], predictions[x], self.task_type, prediction_type, self.y_info - ) - for x in predictions - } - if self.task_type == TaskType.REGRESSION: - score_key = 'rmse' - score_sign = -1 - else: - score_key = 'accuracy' - score_sign = 1 - for part_metrics in metrics.values(): - part_metrics['score'] = score_sign * part_metrics[score_key] - return metrics - -def change_val(dataset: Dataset, val_size: float = 0.2): - # should be done before transformations - - y = np.concatenate([dataset.y['train'], dataset.y['val']], axis=0) - - ixs = np.arange(y.shape[0]) - if dataset.is_regression: - train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777) - else: - train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777, stratify=y) - 
- dataset.y['train'] = y[train_ixs] - dataset.y['val'] = y[val_ixs] - - if dataset.X_num is not None: - X_num = np.concatenate([dataset.X_num['train'], dataset.X_num['val']], axis=0) - dataset.X_num['train'] = X_num[train_ixs] - dataset.X_num['val'] = X_num[val_ixs] - - if dataset.X_cat is not None: - X_cat = np.concatenate([dataset.X_cat['train'], dataset.X_cat['val']], axis=0) - dataset.X_cat['train'] = X_cat[train_ixs] - dataset.X_cat['val'] = X_cat[val_ixs] - - return dataset - -def num_process_nans(dataset: Dataset, policy: Optional[NumNanPolicy]) -> Dataset: - assert dataset.X_num is not None - nan_masks = {k: np.isnan(v) for k, v in dataset.X_num.items()} - if not any(x.any() for x in nan_masks.values()): # type: ignore[code] - assert policy is None - return dataset - - assert policy is not None - if policy == 'drop-rows': - valid_masks = {k: ~v.any(1) for k, v in nan_masks.items()} - assert valid_masks[ - 'test' - ].all(), 'Cannot drop test rows, since this will affect the final metrics.' 
- new_data = {} - for data_name in ['X_num', 'X_cat', 'y']: - data_dict = getattr(dataset, data_name) - if data_dict is not None: - new_data[data_name] = { - k: v[valid_masks[k]] for k, v in data_dict.items() - } - dataset = replace(dataset, **new_data) - elif policy == 'mean': - new_values = np.nanmean(dataset.X_num['train'], axis=0) - X_num = deepcopy(dataset.X_num) - for k, v in X_num.items(): - num_nan_indices = np.where(nan_masks[k]) - v[num_nan_indices] = np.take(new_values, num_nan_indices[1]) - dataset = replace(dataset, X_num=X_num) - else: - assert util.raise_unknown('policy', policy) - return dataset - - -# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 -def normalize( - X: ArrayDict, normalization: Normalization, seed: Optional[int], return_normalizer : bool = False -) -> ArrayDict: - X_train = X['train'] - if normalization == 'standard': - normalizer = sklearn.preprocessing.StandardScaler() - elif normalization == 'minmax': - normalizer = sklearn.preprocessing.MinMaxScaler() - elif normalization == 'quantile': - normalizer = sklearn.preprocessing.QuantileTransformer( - output_distribution='normal', - n_quantiles=max(min(X['train'].shape[0] // 30, 1000), 10), - subsample=1e9, - random_state=seed, - ) - # noise = 1e-3 - # if noise > 0: - # assert seed is not None - # stds = np.std(X_train, axis=0, keepdims=True) - # noise_std = noise / np.maximum(stds, noise) # type: ignore[code] - # X_train = X_train + noise_std * np.random.default_rng(seed).standard_normal( - # X_train.shape - # ) - else: - util.raise_unknown('normalization', normalization) - normalizer.fit(X_train) - if return_normalizer: - return {k: normalizer.transform(v) for k, v in X.items()}, normalizer - return {k: normalizer.transform(v) for k, v in X.items()} - - -def cat_process_nans(X: ArrayDict, policy: Optional[CatNanPolicy]) -> ArrayDict: - assert X is not None - nan_masks = {k: v == CAT_MISSING_VALUE for k, v in 
X.items()} - if any(x.any() for x in nan_masks.values()): # type: ignore[code] - if policy is None: - X_new = X - elif policy == 'most_frequent': - imputer = SimpleImputer(missing_values=CAT_MISSING_VALUE, strategy=policy) # type: ignore[code] - imputer.fit(X['train']) - X_new = {k: cast(np.ndarray, imputer.transform(v)) for k, v in X.items()} - else: - util.raise_unknown('categorical NaN policy', policy) - else: - assert policy is None - X_new = X - return X_new - - -def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: - assert 0.0 < min_frequency < 1.0 - min_count = round(len(X['train']) * min_frequency) - X_new = {x: [] for x in X} - for column_idx in range(X['train'].shape[1]): - counter = Counter(X['train'][:, column_idx].tolist()) - popular_categories = {k for k, v in counter.items() if v >= min_count} - for part in X_new: - X_new[part].append( - [ - (x if x in popular_categories else CAT_RARE_VALUE) - for x in X[part][:, column_idx].tolist() - ] - ) - return {k: np.array(v).T for k, v in X_new.items()} - - -def cat_encode( - X: ArrayDict, - encoding: Optional[CatEncoding], - y_train: Optional[np.ndarray], - seed: Optional[int], - return_encoder : bool = False -) -> Tuple[ArrayDict, bool, Optional[Any]]: # (X, is_converted_to_numerical) - if encoding != 'counter': - y_train = None - - # Step 1. 
Map strings to 0-based ranges - - if encoding is None: - unknown_value = np.iinfo('int64').max - 3 - oe = sklearn.preprocessing.OrdinalEncoder( - handle_unknown='use_encoded_value', # type: ignore[code] - unknown_value=unknown_value, # type: ignore[code] - dtype='int64', # type: ignore[code] - ).fit(X['train']) - encoder = make_pipeline(oe) - encoder.fit(X['train']) - X = {k: encoder.transform(v) for k, v in X.items()} - max_values = X['train'].max(axis=0) - for part in X.keys(): - if part == 'train': continue - for column_idx in range(X[part].shape[1]): - X[part][X[part][:, column_idx] == unknown_value, column_idx] = ( - max_values[column_idx] + 1 - ) - if return_encoder: - return (X, False, encoder) - return (X, False) - - # Step 2. Encode. - - elif encoding == 'one-hot': - ohe = sklearn.preprocessing.OneHotEncoder( - handle_unknown='ignore', sparse=False, dtype=np.float32 # type: ignore[code] - ) - encoder = make_pipeline(ohe) - - # encoder.steps.append(('ohe', ohe)) - encoder.fit(X['train']) - X = {k: encoder.transform(v) for k, v in X.items()} - elif encoding == 'counter': - assert y_train is not None - assert seed is not None - loe = LeaveOneOutEncoder(sigma=0.1, random_state=seed, return_df=False) - encoder.steps.append(('loe', loe)) - encoder.fit(X['train'], y_train) - X = {k: encoder.transform(v).astype('float32') for k, v in X.items()} # type: ignore[code] - if not isinstance(X['train'], pd.DataFrame): - X = {k: v.values for k, v in X.items()} # type: ignore[code] - else: - util.raise_unknown('encoding', encoding) - - if return_encoder: - return X, True, encoder # type: ignore[code] - return (X, True) - - -def build_target( - y: ArrayDict, policy: Optional[YPolicy], task_type: TaskType -) -> Tuple[ArrayDict, Dict[str, Any]]: - info: Dict[str, Any] = {'policy': policy} - if policy is None: - pass - elif policy == 'default': - if task_type == TaskType.REGRESSION: - mean, std = float(y['train'].mean()), float(y['train'].std()) - y = {k: (v - mean) / std for 
k, v in y.items()} - info['mean'] = mean - info['std'] = std - else: - util.raise_unknown('policy', policy) - return y, info - - -@dataclass(frozen=True) -class Transformations: - seed: int = 0 - normalization: Optional[Normalization] = None - num_nan_policy: Optional[NumNanPolicy] = None - cat_nan_policy: Optional[CatNanPolicy] = None - cat_min_frequency: Optional[float] = None - cat_encoding: Optional[CatEncoding] = None - y_policy: Optional[YPolicy] = 'default' - - -def transform_dataset( - dataset: Dataset, - transformations: Transformations, - cache_dir: Optional[Path], - return_transforms: bool = False -) -> Dataset: - # WARNING: the order of transformations matters. Moreover, the current - # implementation is not ideal in that sense. - if cache_dir is not None: - transformations_md5 = hashlib.md5( - str(transformations).encode('utf-8') - ).hexdigest() - transformations_str = '__'.join(map(str, astuple(transformations))) - cache_path = ( - cache_dir / f'cache__{transformations_str}__{transformations_md5}.pickle' - ) - if cache_path.exists(): - cache_transformations, value = util.load_pickle(cache_path) - if transformations == cache_transformations: - print( - f"Using cached features: {cache_dir.name + '/' + cache_path.name}" - ) - return value - else: - raise RuntimeError(f'Hash collision for {cache_path}') - else: - cache_path = None - - if dataset.X_num is not None: - dataset = num_process_nans(dataset, transformations.num_nan_policy) - - num_transform = None - cat_transform = None - X_num = dataset.X_num - - if X_num is not None and transformations.normalization is not None: - X_num, num_transform = normalize( - X_num, - transformations.normalization, - transformations.seed, - return_normalizer=True - ) - num_transform = num_transform - - if dataset.X_cat is None: - assert transformations.cat_nan_policy is None - assert transformations.cat_min_frequency is None - # assert transformations.cat_encoding is None - X_cat = None - else: - X_cat = 
cat_process_nans(dataset.X_cat, transformations.cat_nan_policy) - if transformations.cat_min_frequency is not None: - X_cat = cat_drop_rare(X_cat, transformations.cat_min_frequency) - X_cat, is_num, cat_transform = cat_encode( - X_cat, - transformations.cat_encoding, - dataset.y['train'], - transformations.seed, - return_encoder=True - ) - if is_num: - X_num = ( - X_cat - if X_num is None - else {x: np.hstack([X_num[x], X_cat[x]]) for x in X_num} - ) - X_cat = None - - y, y_info = build_target(dataset.y, transformations.y_policy, dataset.task_type) - - dataset = replace(dataset, X_num=X_num, X_cat=X_cat, y=y, y_info=y_info) - dataset.num_transform = num_transform - dataset.cat_transform = cat_transform - - if cache_path is not None: - util.dump_pickle((transformations, dataset), cache_path) - # if return_transforms: - # return dataset, num_transform, cat_transform - return dataset - - -def build_dataset( - path: Union[str, Path], - transformations: Transformations, - cache: bool -) -> Dataset: - path = Path(path) - dataset = Dataset.from_dir(path) - return transform_dataset(dataset, transformations, path if cache else None) - - -def prepare_tensors( - dataset: Dataset, device: Union[str, torch.device] -) -> Tuple[Optional[TensorDict], Optional[TensorDict], TensorDict]: - X_num, X_cat, Y = ( - None if x is None else {k: torch.as_tensor(v) for k, v in x.items()} - for x in [dataset.X_num, dataset.X_cat, dataset.y] - ) - if device.type != 'cpu': - X_num, X_cat, Y = ( - None if x is None else {k: v.to(device) for k, v in x.items()} - for x in [X_num, X_cat, Y] - ) - assert X_num is not None - assert Y is not None - if not dataset.is_multiclass: - Y = {k: v.float() for k, v in Y.items()} - return X_num, X_cat, Y - -############### -## DataLoader## -############### - -class TabDataset(torch.utils.data.Dataset): - def __init__( - self, dataset : Dataset, split : Literal['train', 'val', 'test'] - ): - super().__init__() - - self.X_num = 
torch.from_numpy(dataset.X_num[split]) if dataset.X_num is not None else None - self.X_cat = torch.from_numpy(dataset.X_cat[split]) if dataset.X_cat is not None else None - self.y = torch.from_numpy(dataset.y[split]) - - assert self.y is not None - assert self.X_num is not None or self.X_cat is not None - - def __len__(self): - return len(self.y) - - def __getitem__(self, idx): - out_dict = { - 'y': self.y[idx].long() if self.y is not None else None, - } - - x = np.empty((0,)) - if self.X_num is not None: - x = self.X_num[idx] - if self.X_cat is not None: - x = torch.cat([x, self.X_cat[idx]], dim=0) - return x.float(), out_dict - -def prepare_dataloader( - dataset : Dataset, - split : str, - batch_size: int, -): - - torch_dataset = TabDataset(dataset, split) - loader = torch.utils.data.DataLoader( - torch_dataset, - batch_size=batch_size, - shuffle=(split == 'train'), - num_workers=1, - ) - while True: - yield from loader - -def prepare_torch_dataloader( - dataset : Dataset, - split : str, - shuffle : bool, - batch_size: int, -) -> torch.utils.data.DataLoader: - - torch_dataset = TabDataset(dataset, split) - loader = torch.utils.data.DataLoader(torch_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1) - - return loader - -def dataset_from_csv(paths : Dict[str, str], cat_features, target, T): - assert 'train' in paths - y = {} - X_num = {} - X_cat = {} if len(cat_features) else None - for split in paths.keys(): - df = pd.read_csv(paths[split]) - y[split] = df[target].to_numpy().astype(float) - if X_cat is not None: - X_cat[split] = df[cat_features].to_numpy().astype(str) - X_num[split] = df.drop(cat_features + [target], axis=1).to_numpy().astype(float) - - dataset = Dataset(X_num, X_cat, y, {}, None, len(np.unique(y['train']))) - return transform_dataset(dataset, T, None) - -class FastTensorDataLoader: - """ - A DataLoader-like object for a set of tensors that can be much faster than - TensorDataset + DataLoader because dataloader grabs individual 
indices of - the dataset and calls cat (slow). - Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 - """ - def __init__(self, *tensors, batch_size=32, shuffle=False): - """ - Initialize a FastTensorDataLoader. - :param *tensors: tensors to store. Must have the same length @ dim 0. - :param batch_size: batch size to load. - :param shuffle: if True, shuffle the data *in-place* whenever an - iterator is created out of this object. - :returns: A FastTensorDataLoader. - """ - assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) - self.tensors = tensors - - self.dataset_len = self.tensors[0].shape[0] - self.batch_size = batch_size - self.shuffle = shuffle - - # Calculate # batches - n_batches, remainder = divmod(self.dataset_len, self.batch_size) - if remainder > 0: - n_batches += 1 - self.n_batches = n_batches - def __iter__(self): - if self.shuffle: - r = torch.randperm(self.dataset_len) - self.tensors = [t[r] for t in self.tensors] - self.i = 0 - return self - - def __next__(self): - if self.i >= self.dataset_len: - raise StopIteration - batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) - self.i += self.batch_size - return batch - - def __len__(self): - return self.n_batches - -def prepare_fast_dataloader( - D : Dataset, - split : str, - batch_size: int -): - if D.X_cat is not None: - if D.X_num is not None: - X = torch.from_numpy(np.concatenate([D.X_num[split], D.X_cat[split]], axis=1)).float() - else: - X = torch.from_numpy(D.X_cat[split]).float() - else: - X = torch.from_numpy(D.X_num[split]).float() - y = torch.from_numpy(D.y[split]) - dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split=='train')) - while True: - yield from dataloader - -def prepare_fast_torch_dataloader( - D : Dataset, - split : str, - batch_size: int -): - if D.X_cat is not None: - X = torch.from_numpy(np.concatenate([D.X_num[split], D.X_cat[split]], axis=1)).float() - else: - X = 
torch.from_numpy(D.X_num[split]).float() - y = torch.from_numpy(D.y[split]) - dataloader = FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=(split=='train')) - return dataloader - -def round_columns(X_real, X_synth, columns): - for col in columns: - uniq = np.unique(X_real[:,col]) - dist = cdist(X_synth[:, col][:, np.newaxis].astype(float), uniq[:, np.newaxis].astype(float)) - X_synth[:, col] = uniq[dist.argmin(axis=1)] - return X_synth - -def concat_features(D : Dataset): - if D.X_num is None: - assert D.X_cat is not None - X = {k: pd.DataFrame(v, columns=range(D.n_features)) for k, v in D.X_cat.items()} - elif D.X_cat is None: - assert D.X_num is not None - X = {k: pd.DataFrame(v, columns=range(D.n_features)) for k, v in D.X_num.items()} - else: - X = { - part: pd.concat( - [ - pd.DataFrame(D.X_num[part], columns=range(D.n_num_features)), - pd.DataFrame( - D.X_cat[part], - columns=range(D.n_num_features, D.n_features), - ), - ], - axis=1, - ) - for part in D.y.keys() - } - - return X - -def concat_to_pd(X_num, X_cat, y): - if X_num is None: - return pd.concat([ - pd.DataFrame(X_cat, columns=list(range(X_cat.shape[1]))), - pd.DataFrame(y, columns=['y']) - ], axis=1) - if X_cat is not None: - return pd.concat([ - pd.DataFrame(X_num, columns=list(range(X_num.shape[1]))), - pd.DataFrame(X_cat, columns=list(range(X_num.shape[1], X_num.shape[1] + X_cat.shape[1]))), - pd.DataFrame(y, columns=['y']) - ], axis=1) - return pd.concat([ - pd.DataFrame(X_num, columns=list(range(X_num.shape[1]))), - pd.DataFrame(y, columns=['y']) - ], axis=1) - -def read_pure_data(path, split='train'): - y = np.load(os.path.join(path, f'y_{split}.npy'), allow_pickle=True) - X_num = None - X_cat = None - if os.path.exists(os.path.join(path, f'X_num_{split}.npy')): - X_num = np.load(os.path.join(path, f'X_num_{split}.npy'), allow_pickle=True) - if os.path.exists(os.path.join(path, f'X_cat_{split}.npy')): - X_cat = np.load(os.path.join(path, f'X_cat_{split}.npy'), allow_pickle=True) - - 
return X_num, X_cat, y - -def read_changed_val(path, val_size=0.2): - path = Path(path) - X_num_train, X_cat_train, y_train = read_pure_data(path, 'train') - X_num_val, X_cat_val, y_val = read_pure_data(path, 'val') - is_regression = load_json(path / 'info.json')['task_type'] == 'regression' - - y = np.concatenate([y_train, y_val], axis=0) - - ixs = np.arange(y.shape[0]) - if is_regression: - train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777) - else: - train_ixs, val_ixs = train_test_split(ixs, test_size=val_size, random_state=777, stratify=y) - y_train = y[train_ixs] - y_val = y[val_ixs] - - if X_num_train is not None: - X_num = np.concatenate([X_num_train, X_num_val], axis=0) - X_num_train = X_num[train_ixs] - X_num_val = X_num[val_ixs] - - if X_cat_train is not None: - X_cat = np.concatenate([X_cat_train, X_cat_val], axis=0) - X_cat_train = X_cat[train_ixs] - X_cat_val = X_cat[val_ixs] - - return X_num_train, X_cat_train, y_train, X_num_val, X_cat_val, y_val - -############# - -def load_dataset_info(dataset_dir_name: str) -> Dict[str, Any]: - path = Path("data/" + dataset_dir_name) - info = util.load_json(path / 'info.json') - info['size'] = info['train_size'] + info['val_size'] + info['test_size'] - info['n_features'] = info['n_num_features'] + info['n_cat_features'] - info['path'] = path - return info diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py deleted file mode 100644 index aeed3e2a..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/deep.py +++ /dev/null @@ -1,168 +0,0 @@ -import statistics -from dataclasses import dataclass -from typing import Any, Callable, Literal, cast - -import rtdl -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import zero -from torch import Tensor - -from .util import TaskType - - -def cos_sin(x: Tensor) -> Tensor: - return torch.cat([torch.cos(x), 
torch.sin(x)], -1) - - -@dataclass -class PeriodicOptions: - n: int # the output size is 2 * n - sigma: float - trainable: bool - initialization: Literal['log-linear', 'normal'] - - -class Periodic(nn.Module): - def __init__(self, n_features: int, options: PeriodicOptions) -> None: - super().__init__() - if options.initialization == 'log-linear': - coefficients = options.sigma ** (torch.arange(options.n) / options.n) - coefficients = coefficients[None].repeat(n_features, 1) - else: - assert options.initialization == 'normal' - coefficients = torch.normal(0.0, options.sigma, (n_features, options.n)) - if options.trainable: - self.coefficients = nn.Parameter(coefficients) # type: ignore[code] - else: - self.register_buffer('coefficients', coefficients) - - def forward(self, x: Tensor) -> Tensor: - assert x.ndim == 2 - return cos_sin(2 * torch.pi * self.coefficients[None] * x[..., None]) - - -def get_n_parameters(m: nn.Module): - return sum(x.numel() for x in m.parameters() if x.requires_grad) - - -def get_loss_fn(task_type: TaskType) -> Callable[..., Tensor]: - return ( - F.binary_cross_entropy_with_logits - if task_type == TaskType.BINCLASS - else F.cross_entropy - if task_type == TaskType.MULTICLASS - else F.mse_loss - ) - - -def default_zero_weight_decay_condition(module_name, module, parameter_name, parameter): - del module_name, parameter - return parameter_name.endswith('bias') or isinstance( - module, - ( - nn.BatchNorm1d, - nn.LayerNorm, - nn.InstanceNorm1d, - rtdl.CLSToken, - rtdl.NumericalFeatureTokenizer, - rtdl.CategoricalFeatureTokenizer, - Periodic, - ), - ) - - -def split_parameters_by_weight_decay( - model: nn.Module, zero_weight_decay_condition=default_zero_weight_decay_condition -) -> list[dict[str, Any]]: - parameters_info = {} - for module_name, module in model.named_modules(): - for parameter_name, parameter in module.named_parameters(): - full_parameter_name = ( - f'{module_name}.{parameter_name}' if module_name else parameter_name - ) - 
parameters_info.setdefault(full_parameter_name, ([], parameter))[0].append( - zero_weight_decay_condition( - module_name, module, parameter_name, parameter - ) - ) - params_with_wd = {'params': []} - params_without_wd = {'params': [], 'weight_decay': 0.0} - for full_parameter_name, (results, parameter) in parameters_info.items(): - (params_without_wd if any(results) else params_with_wd)['params'].append( - parameter - ) - return [params_with_wd, params_without_wd] - - -def make_optimizer( - config: dict[str, Any], - parameter_groups, -) -> optim.Optimizer: - if config['optimizer'] == 'FT-Transformer-default': - return optim.AdamW(parameter_groups, lr=1e-4, weight_decay=1e-5) - return getattr(optim, config['optimizer'])( - parameter_groups, - **{x: config[x] for x in ['lr', 'weight_decay', 'momentum'] if x in config}, - ) - - -def get_lr(optimizer: optim.Optimizer) -> float: - return next(iter(optimizer.param_groups))['lr'] - - -def is_oom_exception(err: RuntimeError) -> bool: - return any( - x in str(err) - for x in [ - 'CUDA out of memory', - 'CUBLAS_STATUS_ALLOC_FAILED', - 'CUDA error: out of memory', - ] - ) - - -def train_with_auto_virtual_batch( - optimizer, - loss_fn, - step, - batch, - chunk_size: int, -) -> tuple[Tensor, int]: - batch_size = len(batch) - random_state = zero.random.get_state() - loss = None - while chunk_size != 0: - try: - zero.random.set_state(random_state) - optimizer.zero_grad() - if batch_size <= chunk_size: - loss = loss_fn(*step(batch)) - loss.backward() - else: - loss = None - for chunk in zero.iter_batches(batch, chunk_size): - chunk_loss = loss_fn(*step(chunk)) - chunk_loss = chunk_loss * (len(chunk) / batch_size) - chunk_loss.backward() - if loss is None: - loss = chunk_loss.detach() - else: - loss += chunk_loss.detach() - except RuntimeError as err: - if not is_oom_exception(err): - raise - chunk_size //= 2 - else: - break - if not chunk_size: - raise RuntimeError('Not enough memory even for batch_size=1') - optimizer.step() - 
return cast(Tensor, loss), chunk_size - - -def process_epoch_losses(losses: list[Tensor]) -> tuple[list[float], float]: - losses_ = torch.stack(losses).tolist() - return losses_, statistics.mean(losses_) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py deleted file mode 100644 index 64be89d7..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/env.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -Have not used in TabDDPM project. -""" - -import datetime -import os -import shutil -import typing as ty -from pathlib import Path - -PROJ = Path('tab-ddpm/').absolute().resolve() -EXP = PROJ / 'exp' -DATA = PROJ / 'data' - - -def get_path(path: ty.Union[str, Path]) -> Path: - if isinstance(path, str): - path = Path(path) - if not path.is_absolute(): - path = PROJ / path - return path.resolve() - - -def get_relative_path(path: ty.Union[str, Path]) -> Path: - return get_path(path).relative_to(PROJ) - - -def duplicate_path( - src: ty.Union[str, Path], alternative_project_dir: ty.Union[str, Path] -) -> None: - src = get_path(src) - alternative_project_dir = get_path(alternative_project_dir) - dst = alternative_project_dir / src.relative_to(PROJ) - dst.parent.mkdir(parents=True, exist_ok=True) - if dst.exists(): - dst = dst.with_name( - dst.name + '_' + datetime.datetime.now().strftime('%Y%m%dT%H%M%S') - ) - (shutil.copytree if src.is_dir() else shutil.copyfile)(src, dst) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py deleted file mode 100644 index bdcac817..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/metrics.py +++ /dev/null @@ -1,158 +0,0 @@ -import enum -from typing import Any, Optional, Tuple, Dict, Union, cast -from functools import partial - -import numpy as np -import scipy.special -import sklearn.metrics as skm - -from . 
import util -from .util import TaskType - - -class PredictionType(enum.Enum): - LOGITS = 'logits' - PROBS = 'probs' - -class MetricsReport: - def __init__(self, report: dict, task_type: TaskType): - self._res = {k: {} for k in report.keys()} - if task_type in (TaskType.BINCLASS, TaskType.MULTICLASS): - self._metrics_names = ["acc", "f1"] - for k in report.keys(): - self._res[k]["acc"] = report[k]["accuracy"] - self._res[k]["f1"] = report[k]["macro avg"]["f1-score"] - if task_type == TaskType.BINCLASS: - self._res[k]["roc_auc"] = report[k]["roc_auc"] - self._metrics_names.append("roc_auc") - - elif task_type == TaskType.REGRESSION: - self._metrics_names = ["r2", "rmse"] - for k in report.keys(): - self._res[k]["r2"] = report[k]["r2"] - self._res[k]["rmse"] = report[k]["rmse"] - else: - raise "Unknown TaskType!" - - def get_splits_names(self) -> list[str]: - return self._res.keys() - - def get_metrics_names(self) -> list[str]: - return self._metrics_names - - def get_metric(self, split: str, metric: str) -> float: - return self._res[split][metric] - - def get_val_score(self) -> float: - return self._res["val"]["r2"] if "r2" in self._res["val"] else self._res["val"]["f1"] - - def get_test_score(self) -> float: - return self._res["test"]["r2"] if "r2" in self._res["test"] else self._res["test"]["f1"] - - def print_metrics(self) -> None: - res = { - "val": {k: np.around(self._res["val"][k], 4) for k in self._res["val"]}, - "test": {k: np.around(self._res["test"][k], 4) for k in self._res["test"]} - } - - print("*"*100) - print("[val]") - print(res["val"]) - print("[test]") - print(res["test"]) - - return res - -class SeedsMetricsReport: - def __init__(self): - self._reports = [] - - def add_report(self, report: MetricsReport) -> None: - self._reports.append(report) - - def get_mean_std(self) -> dict: - res = {k: {} for k in ["train", "val", "test"]} - for split in self._reports[0].get_splits_names(): - for metric in self._reports[0].get_metrics_names(): - 
res[split][metric] = [x.get_metric(split, metric) for x in self._reports] - - agg_res = {k: {} for k in ["train", "val", "test"]} - for split in self._reports[0].get_splits_names(): - for metric in self._reports[0].get_metrics_names(): - for k, f in [("count", len), ("mean", np.mean), ("std", np.std)]: - agg_res[split][f"{metric}-{k}"] = f(res[split][metric]) - self._res = res - self._agg_res = agg_res - - return agg_res - - def print_result(self) -> dict: - res = {split: {k: float(np.around(self._agg_res[split][k], 4)) for k in self._agg_res[split]} for split in ["val", "test"]} - print("="*100) - print("EVAL RESULTS:") - print("[val]") - print(res["val"]) - print("[test]") - print(res["test"]) - print("="*100) - return res - -def calculate_rmse( - y_true: np.ndarray, y_pred: np.ndarray, std: Optional[float] -) -> float: - rmse = skm.mean_squared_error(y_true, y_pred) ** 0.5 - if std is not None: - rmse *= std - return rmse - - -def _get_labels_and_probs( - y_pred: np.ndarray, task_type: TaskType, prediction_type: Optional[PredictionType] -) -> Tuple[np.ndarray, Optional[np.ndarray]]: - assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS) - - if prediction_type is None: - return y_pred, None - - if prediction_type == PredictionType.LOGITS: - probs = ( - scipy.special.expit(y_pred) - if task_type == TaskType.BINCLASS - else scipy.special.softmax(y_pred, axis=1) - ) - elif prediction_type == PredictionType.PROBS: - probs = y_pred - else: - util.raise_unknown('prediction_type', prediction_type) - - assert probs is not None - labels = np.round(probs) if task_type == TaskType.BINCLASS else probs.argmax(axis=1) - return labels.astype('int64'), probs - - -def calculate_metrics( - y_true: np.ndarray, - y_pred: np.ndarray, - task_type: Union[str, TaskType], - prediction_type: Optional[Union[str, PredictionType]], - y_info: Dict[str, Any], -) -> Dict[str, Any]: - # Example: calculate_metrics(y_true, y_pred, 'binclass', 'logits', {}) - task_type = TaskType(task_type) 
- if prediction_type is not None: - prediction_type = PredictionType(prediction_type) - - if task_type == TaskType.REGRESSION: - assert prediction_type is None - assert 'std' in y_info - rmse = calculate_rmse(y_true, y_pred, y_info['std']) - r2 = skm.r2_score(y_true, y_pred) - result = {'rmse': rmse, 'r2': r2} - else: - labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type) - result = cast( - Dict[str, Any], skm.classification_report(y_true, labels, output_dict=True) - ) - if task_type == TaskType.BINCLASS: - result['roc_auc'] = skm.roc_auc_score(y_true, probs) - return result diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/util.py b/src/synthcity/plugins/core/models/tabular_ddpm/.lib/util.py deleted file mode 100644 index 75e05c9c..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.lib/util.py +++ /dev/null @@ -1,433 +0,0 @@ -import argparse -import atexit -import enum -import json -import os -import pickle -import shutil -import sys -import time -import uuid -from copy import deepcopy -from dataclasses import asdict, fields, is_dataclass -from pathlib import Path -from pprint import pprint -from typing import Any, Callable, List, Dict, Type, Optional, Tuple, TypeVar, Union, cast, get_args, get_origin - -import __main__ -import numpy as np -import tomli -import tomli_w -import torch -import zero - -from . 
import env - -RawConfig = Dict[str, Any] -Report = Dict[str, Any] -T = TypeVar('T') - - -class Part(enum.Enum): - TRAIN = 'train' - VAL = 'val' - TEST = 'test' - - def __str__(self) -> str: - return self.value - - -class TaskType(enum.Enum): - BINCLASS = 'binclass' - MULTICLASS = 'multiclass' - REGRESSION = 'regression' - - def __str__(self) -> str: - return self.value - - -class Timer(zero.Timer): - @classmethod - def launch(cls) -> 'Timer': - timer = cls() - timer.run() - return timer - - -def update_training_log(training_log, data, metrics): - def _update(log_part, data_part): - for k, v in data_part.items(): - if isinstance(v, dict): - _update(log_part.setdefault(k, {}), v) - elif isinstance(v, list): - log_part.setdefault(k, []).extend(v) - else: - log_part.setdefault(k, []).append(v) - - _update(training_log, data) - transposed_metrics = {} - for part, part_metrics in metrics.items(): - for metric_name, value in part_metrics.items(): - transposed_metrics.setdefault(metric_name, {})[part] = value - _update(training_log, transposed_metrics) - - -def raise_unknown(unknown_what: str, unknown_value: Any): - raise ValueError(f'Unknown {unknown_what}: {unknown_value}') - - -def _replace(data, condition, value): - def do(x): - if isinstance(x, dict): - return {k: do(v) for k, v in x.items()} - elif isinstance(x, list): - return [do(y) for y in x] - else: - return value if condition(x) else x - - return do(data) - - -_CONFIG_NONE = '__none__' - - -def unpack_config(config: RawConfig) -> RawConfig: - config = cast(RawConfig, _replace(config, lambda x: x == _CONFIG_NONE, None)) - return config - - -def pack_config(config: RawConfig) -> RawConfig: - config = cast(RawConfig, _replace(config, lambda x: x is None, _CONFIG_NONE)) - return config - - -def load_config(path: Union[Path, str]) -> Any: - with open(path, 'rb') as f: - return unpack_config(tomli.load(f)) - - -def dump_config(config: Any, path: Union[Path, str]) -> None: - with open(path, 'wb') as f: - 
tomli_w.dump(pack_config(config), f) - # check that there are no bugs in all these "pack/unpack" things - assert config == load_config(path) - - -def load_json(path: Union[Path, str], **kwargs) -> Any: - return json.loads(Path(path).read_text(), **kwargs) - - -def dump_json(x: Any, path: Union[Path, str], **kwargs) -> None: - kwargs.setdefault('indent', 4) - Path(path).write_text(json.dumps(x, **kwargs) + '\n') - - -def load_pickle(path: Union[Path, str], **kwargs) -> Any: - return pickle.loads(Path(path).read_bytes(), **kwargs) - - -def dump_pickle(x: Any, path: Union[Path, str], **kwargs) -> None: - Path(path).write_bytes(pickle.dumps(x, **kwargs)) - - -def load(path: Union[Path, str], **kwargs) -> Any: - return globals()[f'load_{Path(path).suffix[1:]}'](Path(path), **kwargs) - - -def dump(x: Any, path: Union[Path, str], **kwargs) -> Any: - return globals()[f'dump_{Path(path).suffix[1:]}'](x, Path(path), **kwargs) - - -def _get_output_item_path( - path: Union[str, Path], filename: str, must_exist: bool -) -> Path: - path = env.get_path(path) - if path.suffix == '.toml': - path = path.with_suffix('') - if path.is_dir(): - path = path / filename - else: - assert path.name == filename - assert path.parent.exists() - if must_exist: - assert path.exists() - return path - - -def load_report(path: Path) -> Report: - return load_json(_get_output_item_path(path, 'report.json', True)) - - -def dump_report(report: dict, path: Path) -> None: - dump_json(report, _get_output_item_path(path, 'report.json', False)) - - -def load_predictions(path: Path) -> Dict[str, np.ndarray]: - with np.load(_get_output_item_path(path, 'predictions.npz', True)) as predictions: - return {x: predictions[x] for x in predictions} - - -def dump_predictions(predictions: Dict[str, np.ndarray], path: Path) -> None: - np.savez(_get_output_item_path(path, 'predictions.npz', False), **predictions) - - -def dump_metrics(metrics: Dict[str, Any], path: Path) -> None: - dump_json(metrics, 
_get_output_item_path(path, 'metrics.json', False)) - - -def load_checkpoint(path: Path, *args, **kwargs) -> Dict[str, np.ndarray]: - return torch.load( - _get_output_item_path(path, 'checkpoint.pt', True), *args, **kwargs - ) - - -def get_device() -> torch.device: - if torch.cuda.is_available(): - assert os.environ.get('CUDA_VISIBLE_DEVICES') is not None - return torch.device('cuda:0') - else: - return torch.device('cpu') - - -def _print_sep(c, size=100): - print(c * size) - - -def start( - config_cls: Type[T] = RawConfig, - argv: Optional[List[str]] = None, - patch_raw_config: Optional[Callable[[RawConfig], None]] = None, -) -> Tuple[T, Path, Report]: # config # output dir # report - parser = argparse.ArgumentParser() - parser.add_argument('config', metavar='FILE') - parser.add_argument('--force', action='store_true') - parser.add_argument('--continue', action='store_true', dest='continue_') - if argv is None: - program = __main__.__file__ - args = parser.parse_args() - else: - program = argv[0] - try: - args = parser.parse_args(argv[1:]) - except Exception: - print( - 'Failed to parse `argv`.' - ' Remember that the first item of `argv` must be the path (relative to' - ' the project root) to the script/notebook.' 
- ) - raise - args = parser.parse_args(argv) - - snapshot_dir = os.environ.get('SNAPSHOT_PATH') - if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists(): - assert args.continue_ - - config_path = env.get_path(args.config) - output_dir = config_path.with_suffix('') - _print_sep('=') - print(f'[output] {output_dir}') - _print_sep('=') - - assert config_path.exists() - raw_config = load_config(config_path) - if patch_raw_config is not None: - patch_raw_config(raw_config) - if is_dataclass(config_cls): - config = from_dict(config_cls, raw_config) - full_raw_config = asdict(config) - else: - assert config_cls is dict - full_raw_config = config = raw_config - full_raw_config = asdict(config) - - if output_dir.exists(): - if args.force: - print('Removing the existing output and creating a new one...') - shutil.rmtree(output_dir) - output_dir.mkdir() - elif not args.continue_: - backup_output(output_dir) - print('The output directory already exists. Done!\n') - sys.exit() - elif output_dir.joinpath('DONE').exists(): - backup_output(output_dir) - print('The "DONE" file already exists. 
Done!') - sys.exit() - else: - print('Continuing with the existing output...') - else: - print('Creating the output...') - output_dir.mkdir() - - report = { - 'program': str(env.get_relative_path(program)), - 'environment': {}, - 'config': full_raw_config, - } - if torch.cuda.is_available(): # type: ignore[code] - report['environment'].update( - { - 'CUDA_VISIBLE_DEVICES': os.environ.get('CUDA_VISIBLE_DEVICES'), - 'gpus': zero.hardware.get_gpus_info(), - 'torch.version.cuda': torch.version.cuda, - 'torch.backends.cudnn.version()': torch.backends.cudnn.version(), # type: ignore[code] - 'torch.cuda.nccl.version()': torch.cuda.nccl.version(), # type: ignore[code] - } - ) - dump_report(report, output_dir) - dump_json(raw_config, output_dir / 'raw_config.json') - _print_sep('-') - pprint(full_raw_config, width=100) - _print_sep('-') - return cast(config_cls, config), output_dir, report - - -_LAST_SNAPSHOT_TIME = None - - -def backup_output(output_dir: Path) -> None: - backup_dir = os.environ.get('TMP_OUTPUT_PATH') - snapshot_dir = os.environ.get('SNAPSHOT_PATH') - if backup_dir is None: - assert snapshot_dir is None - return - assert snapshot_dir is not None - - try: - relative_output_dir = output_dir.relative_to(env.PROJ) - except ValueError: - return - - for dir_ in [backup_dir, snapshot_dir]: - new_output_dir = dir_ / relative_output_dir - prev_backup_output_dir = new_output_dir.with_name(new_output_dir.name + '_prev') - new_output_dir.parent.mkdir(exist_ok=True, parents=True) - if new_output_dir.exists(): - new_output_dir.rename(prev_backup_output_dir) - shutil.copytree(output_dir, new_output_dir) - # the case for evaluate.py which automatically creates configs - if output_dir.with_suffix('.toml').exists(): - shutil.copyfile( - output_dir.with_suffix('.toml'), new_output_dir.with_suffix('.toml') - ) - if prev_backup_output_dir.exists(): - shutil.rmtree(prev_backup_output_dir) - - global _LAST_SNAPSHOT_TIME - if _LAST_SNAPSHOT_TIME is None or time.time() - 
_LAST_SNAPSHOT_TIME > 10 * 60: - import nirvana_dl.snapshot # type: ignore[code] - - nirvana_dl.snapshot.dump_snapshot() - _LAST_SNAPSHOT_TIME = time.time() - print('The snapshot was saved!') - - -def _get_scores(metrics: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, float]]: - return ( - {k: v['score'] for k, v in metrics.items()} - if 'score' in next(iter(metrics.values())) - else None - ) - - -def format_scores(metrics: Dict[str, Dict[str, Any]]) -> str: - return ' '.join( - f"[{x}] {metrics[x]['score']:.3f}" - for x in ['test', 'val', 'train'] - if x in metrics - ) - - -def finish(output_dir: Path, report: dict) -> None: - print() - _print_sep('=') - - metrics = report.get('metrics') - if metrics is not None: - scores = _get_scores(metrics) - if scores is not None: - dump_json(scores, output_dir / 'scores.json') - print(format_scores(metrics)) - _print_sep('-') - - dump_report(report, output_dir) - json_output_path = os.environ.get('JSON_OUTPUT_FILE') - if json_output_path: - try: - key = str(output_dir.relative_to(env.PROJ)) - except ValueError: - pass - else: - json_output_path = Path(json_output_path) - try: - json_data = json.loads(json_output_path.read_text()) - except (FileNotFoundError, json.decoder.JSONDecodeError): - json_data = {} - json_data[key] = load_json(output_dir / 'report.json') - json_output_path.write_text(json.dumps(json_data, indent=4)) - shutil.copyfile( - json_output_path, - os.path.join(os.environ['SNAPSHOT_PATH'], 'json_output.json'), - ) - - output_dir.joinpath('DONE').touch() - backup_output(output_dir) - print(f'Done! 
| {report.get("time")} | {output_dir}') - _print_sep('=') - print() - - -def from_dict(datacls: Type[T], data: dict) -> T: - assert is_dataclass(datacls) - data = deepcopy(data) - for field in fields(datacls): - if field.name not in data: - continue - if is_dataclass(field.type): - data[field.name] = from_dict(field.type, data[field.name]) - elif ( - get_origin(field.type) is Union - and len(get_args(field.type)) == 2 - and get_args(field.type)[1] is type(None) - and is_dataclass(get_args(field.type)[0]) - ): - if data[field.name] is not None: - data[field.name] = from_dict(get_args(field.type)[0], data[field.name]) - return datacls(**data) - - -def replace_factor_with_value( - config: RawConfig, - key: str, - reference_value: int, - bounds: Tuple[float, float], -) -> None: - factor_key = key + '_factor' - if factor_key not in config: - assert key in config - else: - assert key not in config - factor = config.pop(factor_key) - assert bounds[0] <= factor <= bounds[1] - config[key] = int(factor * reference_value) - - -def get_temporary_copy(path: Union[str, Path]) -> Path: - path = env.get_path(path) - assert not path.is_dir() and not path.is_symlink() - tmp_path = path.with_name( - path.stem + '___' + str(uuid.uuid4()).replace('-', '') + path.suffix - ) - shutil.copyfile(path, tmp_path) - atexit.register(lambda: tmp_path.unlink()) - return tmp_path - - -def get_python(): - python = Path('python3.9') - return str(python) if python.exists() else 'python' - -def get_catboost_config(real_data_path, is_cv=False): - ds_name = Path(real_data_path).name - C = load_json(f'tuned_models/catboost/{ds_name}_cv.json') - return C \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.pipeline.py b/src/synthcity/plugins/core/models/tabular_ddpm/.pipeline.py deleted file mode 100644 index f6855f6b..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.pipeline.py +++ /dev/null @@ -1,80 +0,0 @@ -import tomli -import shutil -import os -import 
argparse -from train import train -from sample import sample -import pandas as pd -import matplotlib.pyplot as plt -import zero -import lib -import torch - -def load_config(path) : - with open(path, 'rb') as f: - return tomli.load(f) - -def save_file(parent_dir, config_path): - try: - dst = os.path.join(parent_dir) - os.makedirs(os.path.dirname(dst), exist_ok=True) - shutil.copyfile(os.path.abspath(config_path), dst) - except shutil.SameFileError: - pass - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--config', metavar='FILE') - parser.add_argument('--train', action='store_true', default=False) - parser.add_argument('--sample', action='store_true', default=False) - parser.add_argument('--eval', action='store_true', default=False) - parser.add_argument('--change_val', action='store_true', default=False) - - args = parser.parse_args() - raw_config = lib.load_config(args.config) - if 'device' in raw_config: - device = torch.device(raw_config['device']) - else: - device = torch.device('cuda:1') - - timer = zero.Timer() - timer.run() - save_file(os.path.join(raw_config['parent_dir'], 'config.toml'), args.config) - - if args.train: - train( - **raw_config['train']['main'], - **raw_config['diffusion_params'], - parent_dir=raw_config['parent_dir'], - real_data_path=raw_config['real_data_path'], - model_type=raw_config['model_type'], - model_params=raw_config['model_params'], - T_dict=raw_config['train']['T'], - num_numerical_features=raw_config['num_numerical_features'], - device=device, - change_val=args.change_val - ) - if args.sample: - sample( - num_samples=raw_config['sample']['num_samples'], - batch_size=raw_config['sample']['batch_size'], - disbalance=raw_config['sample'].get('disbalance', None), - **raw_config['diffusion_params'], - parent_dir=raw_config['parent_dir'], - real_data_path=raw_config['real_data_path'], - model_path=os.path.join(raw_config['parent_dir'], 'model.pt'), - model_type=raw_config['model_type'], - 
model_params=raw_config['model_params'], - T_dict=raw_config['train']['T'], - num_numerical_features=raw_config['num_numerical_features'], - device=device, - seed=raw_config['sample'].get('seed', 0), - change_val=args.change_val - ) - - save_file(os.path.join(raw_config['parent_dir'], 'info.json'), os.path.join(raw_config['real_data_path'], 'info.json')) - - print(f'Elapsed time: {str(timer)}') - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.sample.py b/src/synthcity/plugins/core/models/tabular_ddpm/.sample.py deleted file mode 100644 index abc68162..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.sample.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -import numpy as np -import zero -import os -from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion -from .utils import FoundNANsError -from utils_train import get_model, make_dataset -from .lib import round_columns -import lib - -def to_good_ohe(ohe, X): - indices = np.cumsum([0] + ohe._n_features_outs) - Xres = [] - for i in range(1, len(indices)): - x_ = np.max(X[:, indices[i - 1]:indices[i]], axis=1) - t = X[:, indices[i - 1]:indices[i]] - x_.reshape(-1, 1) - Xres.append(np.where(t >= 0, 1, 0)) - return np.hstack(Xres) - -def sample( - parent_dir, - real_data_path = 'data/higgs-small', - batch_size = 2000, - num_samples = 0, - model_type = 'mlp', - model_params = None, - model_path = None, - num_timesteps = 1000, - gaussian_loss_type = 'mse', - scheduler = 'cosine', - T_dict = None, - num_numerical_features = 0, - disbalance = None, - device = torch.device('cuda:1'), - seed = 0, - change_val = False -): - zero.improve_reproducibility(seed) - - T = lib.Transformations(**T_dict) - D = make_dataset( - real_data_path, - T, - num_classes=model_params['num_classes'], - is_y_cond=model_params['is_y_cond'], - change_val=change_val - ) - - K = np.array(D.get_category_sizes('train')) - if len(K) == 0 or 
T_dict['cat_encoding'] == 'one-hot': - K = np.array([0]) - - num_numerical_features_ = D.X_num['train'].shape[1] if D.X_num is not None else 0 - d_in = np.sum(K) + num_numerical_features_ - model_params['d_in'] = int(d_in) - model = get_model( - model_type, - model_params, - num_numerical_features_, - category_sizes=D.get_category_sizes('train') - ) - - model.load_state_dict( - torch.load(model_path, map_location="cpu") - ) - - diffusion = GaussianMultinomialDiffusion( - K, - num_numerical_features=num_numerical_features_, - denoise_fn=model, num_timesteps=num_timesteps, - gaussian_loss_type=gaussian_loss_type, scheduler=scheduler, device=device - ) - - diffusion.to(device) - diffusion.eval() - - _, empirical_class_dist = torch.unique(torch.from_numpy(D.y['train']), return_counts=True) - # empirical_class_dist = empirical_class_dist.float() + torch.tensor([-5000., 10000.]).float() - if disbalance == 'fix': - empirical_class_dist[0], empirical_class_dist[1] = empirical_class_dist[1], empirical_class_dist[0] - x_gen, y_gen = diffusion.sample_all(num_samples, batch_size, empirical_class_dist.float(), ddim=False) - - elif disbalance == 'fill': - ix_major = empirical_class_dist.argmax().item() - val_major = empirical_class_dist[ix_major].item() - x_gen, y_gen = [], [] - for i in range(empirical_class_dist.shape[0]): - if i == ix_major: - continue - distrib = torch.zeros_like(empirical_class_dist) - distrib[i] = 1 - num_samples = val_major - empirical_class_dist[i].item() - x_temp, y_temp = diffusion.sample_all(num_samples, batch_size, distrib.float(), ddim=False) - x_gen.append(x_temp) - y_gen.append(y_temp) - - x_gen = torch.cat(x_gen, dim=0) - y_gen = torch.cat(y_gen, dim=0) - - else: - x_gen, y_gen = diffusion.sample_all(num_samples, batch_size, empirical_class_dist.float(), ddim=False) - - - # try: - # except FoundNANsError as ex: - # print("Found NaNs during sampling!") - # loader = lib.prepare_fast_dataloader(D, 'train', 8) - # x_gen = next(loader)[0] - # y_gen = 
torch.multinomial( - # empirical_class_dist.float(), - # num_samples=8, - # replacement=True - # ) - X_gen, y_gen = x_gen.numpy(), y_gen.numpy() - - ### - # X_num_unnorm = X_gen[:, :num_numerical_features] - # lo = np.percentile(X_num_unnorm, 2.5, axis=0) - # hi = np.percentile(X_num_unnorm, 97.5, axis=0) - # idx = (lo < X_num_unnorm) & (hi > X_num_unnorm) - # X_gen = X_gen[np.all(idx, axis=1)] - # y_gen = y_gen[np.all(idx, axis=1)] - ### - - num_numerical_features = num_numerical_features + int(D.is_regression and not model_params["is_y_cond"]) - - X_num_ = X_gen - if num_numerical_features < X_gen.shape[1]: - np.save(os.path.join(parent_dir, 'X_cat_unnorm'), X_gen[:, num_numerical_features:]) - # _, _, cat_encoder = lib.cat_encode({'train': X_cat_real}, T_dict['cat_encoding'], y_real, T_dict['seed'], True) - if T_dict['cat_encoding'] == 'one-hot': - X_gen[:, num_numerical_features:] = to_good_ohe(D.cat_transform.steps[0][1], X_num_[:, num_numerical_features:]) - X_cat = D.cat_transform.inverse_transform(X_gen[:, num_numerical_features:]) - - if num_numerical_features_ != 0: - # _, normalize = lib.normalize({'train' : X_num_real}, T_dict['normalization'], T_dict['seed'], True) - np.save(os.path.join(parent_dir, 'X_num_unnorm'), X_gen[:, :num_numerical_features]) - X_num_ = D.num_transform.inverse_transform(X_gen[:, :num_numerical_features]) - X_num = X_num_[:, :num_numerical_features] - - X_num_real = np.load(os.path.join(real_data_path, "X_num_train.npy"), allow_pickle=True) - disc_cols = [] - for col in range(X_num_real.shape[1]): - uniq_vals = np.unique(X_num_real[:, col]) - if len(uniq_vals) <= 32 and ((uniq_vals - np.round(uniq_vals)) == 0).all(): - disc_cols.append(col) - print("Discrete cols:", disc_cols) - if model_params['num_classes'] == 0: - y_gen = X_num[:, 0] - X_num = X_num[:, 1:] - if len(disc_cols): - X_num = round_columns(X_num_real, X_num, disc_cols) - - if num_numerical_features != 0: - print("Num shape: ", X_num.shape) - 
np.save(os.path.join(parent_dir, 'X_num_train'), X_num) - if num_numerical_features < X_gen.shape[1]: - np.save(os.path.join(parent_dir, 'X_cat_train'), X_cat) - np.save(os.path.join(parent_dir, 'y_train'), y_gen) \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.train.py b/src/synthcity/plugins/core/models/tabular_ddpm/.train.py deleted file mode 100644 index 85cac744..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.train.py +++ /dev/null @@ -1,156 +0,0 @@ -from copy import deepcopy -import torch -import os -import numpy as np -import zero -from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion -from utils_train import get_model, make_dataset, update_ema -from . import lib -import pandas as pd - -class Trainer: - def __init__(self, diffusion, train_iter, lr, weight_decay, steps, device=torch.device('cuda:1')): - self.diffusion = diffusion - self.ema_model = deepcopy(self.diffusion._denoise_fn) - for param in self.ema_model.parameters(): - param.detach_() - - self.train_iter = train_iter - self.steps = steps - self.init_lr = lr - self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay) - self.device = device - self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) - self.log_every = 100 - self.print_every = 500 - self.ema_every = 1000 - - def _anneal_lr(self, step): - frac_done = step / self.steps - lr = self.init_lr * (1 - frac_done) - for param_group in self.optimizer.param_groups: - param_group["lr"] = lr - - def _run_step(self, x, out_dict): - x = x.to(self.device) - for k in out_dict: - out_dict[k] = out_dict[k].long().to(self.device) - self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict) - loss = loss_multi + loss_gauss - loss.backward() - self.optimizer.step() - - return loss_multi, loss_gauss - - def run_loop(self): - step = 0 - curr_loss_multi = 0.0 - curr_loss_gauss = 0.0 - - 
curr_count = 0 - while step < self.steps: - x, out_dict = next(self.train_iter) - out_dict = {'y': out_dict} - batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict) - - self._anneal_lr(step) - - curr_count += len(x) - curr_loss_multi += batch_loss_multi.item() * len(x) - curr_loss_gauss += batch_loss_gauss.item() * len(x) - - if (step + 1) % self.log_every == 0: - mloss = np.around(curr_loss_multi / curr_count, 4) - gloss = np.around(curr_loss_gauss / curr_count, 4) - if (step + 1) % self.print_every == 0: - print(f'Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') - self.loss_history.loc[len(self.loss_history)] =[step + 1, mloss, gloss, mloss + gloss] - curr_count = 0 - curr_loss_gauss = 0.0 - curr_loss_multi = 0.0 - - update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters()) - - step += 1 - -def train( - parent_dir, - real_data_path = 'data/higgs-small', - steps = 1000, - lr = 0.002, - weight_decay = 1e-4, - batch_size = 1024, - model_type = 'mlp', - model_params = None, - num_timesteps = 1000, - gaussian_loss_type = 'mse', - scheduler = 'cosine', - T_dict = None, - num_numerical_features = 0, - device = torch.device('cuda:1'), - seed = 0, - change_val = False -): - real_data_path = os.path.normpath(real_data_path) - parent_dir = os.path.normpath(parent_dir) - - zero.improve_reproducibility(seed) - - T = lib.Transformations(**T_dict) - - dataset = make_dataset( - real_data_path, - T, - num_classes=model_params['num_classes'], - is_y_cond=model_params['is_y_cond'], - change_val=change_val - ) - - K = np.array(dataset.get_category_sizes('train')) - if len(K) == 0 or T_dict['cat_encoding'] == 'one-hot': - K = np.array([0]) - print(K) - - num_numerical_features = dataset.X_num['train'].shape[1] if dataset.X_num is not None else 0 - d_in = np.sum(K) + num_numerical_features - model_params['d_in'] = d_in - print(d_in) - - print(model_params) - model = get_model( - model_type, - model_params, - 
num_numerical_features, - category_sizes=dataset.get_category_sizes('train') - ) - model.to(device) - - # train_loader = lib.prepare_beton_loader(dataset, split='train', batch_size=batch_size) - train_loader = lib.prepare_fast_dataloader(dataset, split='train', batch_size=batch_size) - - diffusion = GaussianMultinomialDiffusion( - num_classes=K, - num_numerical_features=num_numerical_features, - denoise_fn=model, - gaussian_loss_type=gaussian_loss_type, - num_timesteps=num_timesteps, - scheduler=scheduler, - device=device - ) - diffusion.to(device) - diffusion.train() - - trainer = Trainer( - diffusion, - train_loader, - lr=lr, - weight_decay=weight_decay, - steps=steps, - device=device - ) - trainer.run_loop() - - trainer.loss_history.to_csv(os.path.join(parent_dir, 'loss.csv'), index=False) - torch.save(diffusion._denoise_fn.state_dict(), os.path.join(parent_dir, 'model.pt')) - torch.save(trainer.ema_model.state_dict(), os.path.join(parent_dir, 'model_ema.pt')) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.tune.py b/src/synthcity/plugins/core/models/tabular_ddpm/.tune.py deleted file mode 100644 index 5a95dc23..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.tune.py +++ /dev/null @@ -1,127 +0,0 @@ -import subprocess -import lib -import os -import optuna -from copy import deepcopy -import shutil -import argparse -from pathlib import Path - -parser = argparse.ArgumentParser() -parser.add_argument('ds_name', type=str) -parser.add_argument('train_size', type=int) -parser.add_argument('eval_type', type=str) -parser.add_argument('eval_model', type=str) -parser.add_argument('prefix', type=str) -parser.add_argument('--eval_seeds', action='store_true', default=False) - -args = parser.parse_args() -train_size = args.train_size -ds_name = args.ds_name -eval_type = args.eval_type -assert eval_type in ('merged', 'synthetic') -prefix = str(args.prefix) - -pipeline = f'scripts/pipeline.py' -base_config_path = f'exp/{ds_name}/config.toml' 
-parent_path = Path(f'exp/{ds_name}/') -exps_path = Path(f'exp/{ds_name}/many-exps/') # temporary dir. maybe will be replaced with tempdiвdr -eval_seeds = f'scripts/eval_seeds.py' - -os.makedirs(exps_path, exist_ok=True) - -def _suggest_mlp_layers(trial): - def suggest_dim(name): - t = trial.suggest_int(name, d_min, d_max) - return 2 ** t - min_n_layers, max_n_layers, d_min, d_max = 1, 4, 7, 10 - n_layers = 2 * trial.suggest_int('n_layers', min_n_layers, max_n_layers) - d_first = [suggest_dim('d_first')] if n_layers else [] - d_middle = ( - [suggest_dim('d_middle')] * (n_layers - 2) - if n_layers > 2 - else [] - ) - d_last = [suggest_dim('d_last')] if n_layers > 1 else [] - d_layers = d_first + d_middle + d_last - return d_layers - -def objective(trial): - - lr = trial.suggest_loguniform('lr', 0.00001, 0.003) - d_layers = _suggest_mlp_layers(trial) - weight_decay = 0.0 - batch_size = trial.suggest_categorical('batch_size', [256, 4096]) - steps = trial.suggest_categorical('steps', [5000, 20000, 30000]) - # steps = trial.suggest_categorical('steps', [500]) # for debug - gaussian_loss_type = 'mse' - # scheduler = trial.suggest_categorical('scheduler', ['cosine', 'linear']) - num_timesteps = trial.suggest_categorical('num_timesteps', [100, 1000]) - num_samples = int(train_size * (2 ** trial.suggest_int('num_samples', -2, 1))) - - base_config = lib.load_config(base_config_path) - - base_config['train']['main']['lr'] = lr - base_config['train']['main']['steps'] = steps - base_config['train']['main']['batch_size'] = batch_size - base_config['train']['main']['weight_decay'] = weight_decay - base_config['model_params']['rtdl_params']['d_layers'] = d_layers - base_config['eval']['type']['eval_type'] = eval_type - base_config['sample']['num_samples'] = num_samples - base_config['diffusion_params']['gaussian_loss_type'] = gaussian_loss_type - base_config['diffusion_params']['num_timesteps'] = num_timesteps - # base_config['diffusion_params']['scheduler'] = scheduler - - 
base_config['parent_dir'] = str(exps_path / f"{trial.number}") - base_config['eval']['type']['eval_model'] = args.eval_model - if args.eval_model == "mlp": - base_config['eval']['T']['normalization'] = "quantile" - base_config['eval']['T']['cat_encoding'] = "one-hot" - - trial.set_user_attr("config", base_config) - - lib.dump_config(base_config, exps_path / 'config.toml') - - subprocess.run(['python3.9', f'{pipeline}', '--config', f'{exps_path / "config.toml"}', '--train', '--change_val'], check=True) - - n_datasets = 5 - score = 0.0 - - for sample_seed in range(n_datasets): - base_config['sample']['seed'] = sample_seed - lib.dump_config(base_config, exps_path / 'config.toml') - - subprocess.run(['python3.9', f'{pipeline}', '--config', f'{exps_path / "config.toml"}', '--sample', '--eval', '--change_val'], check=True) - - report_path = str(Path(base_config['parent_dir']) / f'results_{args.eval_model}.json') - report = lib.load_json(report_path) - - if 'r2' in report['metrics']['val']: - score += report['metrics']['val']['r2'] - else: - score += report['metrics']['val']['macro avg']['f1-score'] - - shutil.rmtree(exps_path / f"{trial.number}") - - return score / n_datasets - -study = optuna.create_study( - direction='maximize', - sampler=optuna.samplers.TPESampler(seed=0), -) - -study.optimize(objective, n_trials=50, show_progress_bar=True) - -best_config_path = parent_path / f'{prefix}_best/config.toml' -best_config = study.best_trial.user_attrs['config'] -best_config["parent_dir"] = str(parent_path / f'{prefix}_best/') - -os.makedirs(parent_path / f'{prefix}_best', exist_ok=True) -lib.dump_config(best_config, best_config_path) -lib.dump_json(optuna.importance.get_param_importances(study), parent_path / f'{prefix}_best/importance.json') - -subprocess.run(['python3.9', f'{pipeline}', '--config', f'{best_config_path}', '--train', '--sample'], check=True) - -if args.eval_seeds: - best_exp = str(parent_path / f'{prefix}_best/config.toml') - subprocess.run(['python3.9', 
f'{eval_seeds}', '--config', f'{best_exp}', '10', "ddpm", eval_type, args.eval_model, '5'], check=True) \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/.utils_train.py b/src/synthcity/plugins/core/models/tabular_ddpm/.utils_train.py deleted file mode 100644 index 3062b15d..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/.utils_train.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import os -import lib -from .modules import MLPDiffusion, ResNetDiffusion - -def get_model( - model_name, - model_params, - n_num_features, - category_sizes -): - if model_name == 'mlp': - model = MLPDiffusion(**model_params) - elif model_name == 'resnet': - model = ResNetDiffusion(**model_params) - else: - raise "Unknown model!" - return model - -def update_ema(target_params, source_params, rate=0.999): - """ - Update target parameters to be closer to those of source parameters using - an exponential moving average. - :param target_params: the target parameter sequence. - :param source_params: the source parameter sequence. - :param rate: the EMA rate (closer to 1 means slower). 
- """ - for targ, src in zip(target_params, source_params): - targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) - -def concat_y_to_X(X, y): - if X is None: - return y.reshape(-1, 1) - return np.concatenate([y.reshape(-1, 1), X], axis=1) - -def make_dataset( - data_path: str, - T: lib.Transformations, - num_classes: int, - is_y_cond: bool, - change_val: bool -): - # classification - if num_classes > 0: - X_cat = {} if os.path.exists(os.path.join(data_path, 'X_cat_train.npy')) or not is_y_cond else None - X_num = {} if os.path.exists(os.path.join(data_path, 'X_num_train.npy')) else None - y = {} - - for split in ['train', 'val', 'test']: - X_num_t, X_cat_t, y_t = lib.read_pure_data(data_path, split) - if X_num is not None: - X_num[split] = X_num_t - if not is_y_cond: - X_cat_t = concat_y_to_X(X_cat_t, y_t) - if X_cat is not None: - X_cat[split] = X_cat_t - y[split] = y_t - else: - # regression - X_cat = {} if os.path.exists(os.path.join(data_path, 'X_cat_train.npy')) else None - X_num = {} if os.path.exists(os.path.join(data_path, 'X_num_train.npy')) or not is_y_cond else None - y = {} - - for split in ['train', 'val', 'test']: - X_num_t, X_cat_t, y_t = lib.read_pure_data(data_path, split) - if not is_y_cond: - X_num_t = concat_y_to_X(X_num_t, y_t) - if X_num is not None: - X_num[split] = X_num_t - if X_cat is not None: - X_cat[split] = X_cat_t - y[split] = y_t - - info = lib.load_json(os.path.join(data_path, 'info.json')) - - D = lib.Dataset( - X_num, - X_cat, - y, - y_info={}, - task_type=lib.TaskType(info['task_type']), - n_classes=info.get('n_classes') - ) - - if change_val: - D = lib.change_val(D) - - return lib.transform_dataset(D, T, None) \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/README.md b/src/synthcity/plugins/core/models/tabular_ddpm/README.md deleted file mode 100644 index 3d418685..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# TabDDPM: 
Modelling Tabular Data with Diffusion Models - -Adapted from https://github.com/rotot0/tab-ddpm. diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index d4fa28e6..95d31db6 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -32,6 +32,7 @@ def __init__( gaussian_loss_type = 'mse', scheduler = 'cosine', device: Any = DEVICE, + verbose: int = 0, log_interval: int = 100, print_interval: int = 500, # model params @@ -45,7 +46,7 @@ def __init__( ) -> None: super().__init__() self.__dict__.update(locals()) - del self.self, self.kwargs + del self.self def _anneal_lr(self, step): frac_done = step / self.steps @@ -69,7 +70,7 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): n_labels = cond.nunique() else: n_labels = 0 - + cat_cols = discrete_columns(X, return_counts=True) ini_cols = X.columns cat_cols, cat_counts = zip(*cat_cols) @@ -86,9 +87,9 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): dim_t = self.dim_label_emb ) - tensors = [X] if cond is None else [X, cond] - tensors = [torch.tensor(t.values, dtype=torch.float32, device=self.device) for t in tensors] - self.dataloader = TensorDataLoader(tensors, batch_size=self.batch_size) + tensors = [torch.tensor(t.values, dtype=torch.float32, device=self.device) + for t in ([X] if cond is None else [X, cond])] + self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( model_type=self.model_type, @@ -98,7 +99,8 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): gaussian_loss_type=self.gaussian_loss_type, num_timesteps=self.num_timesteps, scheduler=self.scheduler, - device=self.device + device=self.device, + verbose=self.verbose, ).to(self.device) self.ema_model = deepcopy(self.diffusion.denoise_fn) @@ -110,6 +112,10 @@ def fit(self, X: pd.DataFrame, cond=None, 
**kwargs: Any): self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) + if self.verbose: + print("Starting training") + print(self) + for step, (x, y) in enumerate(self.dataloader): curr_loss_multi = 0.0 curr_loss_gauss = 0.0 @@ -131,7 +137,7 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): if (step + 1) % self.log_interval == 0: mloss = np.around(curr_loss_multi / curr_count, 4) gloss = np.around(curr_loss_gauss / curr_count, 4) - if (step + 1) % self.print_interval == 0: + if self.verbose and (step + 1) % self.print_interval == 0: print(f'Step {(step + 1)}/{self.n_iter} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') self.loss_history.loc[len(self.loss_history)] = [ step + 1, mloss, gloss, mloss + gloss] diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index dfcbd00a..580de6fe 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -74,22 +74,28 @@ def __init__( multinomial_loss_type='vb_stochastic', parametrization='x0', scheduler='cosine', - device=torch.device('cpu') + device=torch.device('cpu'), + verbose=0 ): super(GaussianMultinomialDiffusion, self).__init__() assert multinomial_loss_type in ('vb_stochastic', 'vb_all') assert parametrization in ('x0', 'direct') + + if verbose: + self.print = print + else: + self.print = lambda *args, **kwargs: None if multinomial_loss_type == 'vb_all': - print('Computing the loss using the bound on _all_ timesteps.' + self.print('Computing the loss using the bound on _all_ timesteps.' 
' This is expensive both in terms of memory and computation.') self.num_numerics = num_numerical_features self.num_classes = num_categorical_features self.num_classes_expanded = torch.from_numpy( - np.concatenate([num_categorical_features[i].repeat(num_categorical_features[i]) for i in range(len(num_categorical_features))]) - ).to(device) + np.concatenate([np.repeat(k, k) for k in num_categorical_features], + dtype=np.float32)).to(device) self.dim_input = self.num_numerics + sum(self.num_classes) self.slices_for_classes = [np.arange(self.num_classes[0])] @@ -409,7 +415,7 @@ def q_posterior(self, log_x_start, log_x_t, t): # EV_log_qxt_x0 = self.q_pred(log_x_start, t) - # print('sum exp', EV_log_qxt_x0.exp().sum(1).mean()) + # self.print('sum exp', EV_log_qxt_x0.exp().sum(1).mean()) # assert False # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1) @@ -800,7 +806,7 @@ def gaussian_ddim_sample( b = x.shape[0] device = x.device for t in reversed(range(T)): - print(f'Sample timestep {t:4d}', end='\r') + self.print(f'Sample timestep {t:4d}', end='\r') t_array = (torch.ones(b, device=device) * t).long() out_num = self.denoise_fn(x, t_array, y=cond) x = self.gaussian_ddim_step( @@ -808,7 +814,7 @@ def gaussian_ddim_sample( x, t_array ) - print() + self.print() return x @@ -854,7 +860,7 @@ def gaussian_ddim_reverse_sample( b = x.shape[0] device = x.device for t in range(T): - print(f'Reverse timestep {t:4d}', end='\r') + self.print(f'Reverse timestep {t:4d}', end='\r') t_array = (torch.ones(b, device=device) * t).long() out_num = self.denoise_fn(x, t_array, y=cond) x = self.gaussian_ddim_reverse_step( @@ -863,7 +869,7 @@ def gaussian_ddim_reverse_sample( t_array, eta=0.0 ) - print() + self.print() return x @@ -923,7 +929,7 @@ def sample_ddim(self, num_samples, cond=None): # ) # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): - print(f'Sample timestep {i:4d}', end='\r') + self.print(f'Sample timestep {i:4d}', end='\r') t = 
torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), @@ -935,7 +941,7 @@ def sample_ddim(self, num_samples, cond=None): if has_cat: log_z = self.multinomial_ddim_step(model_out_cat, log_z, t) - print() + self.print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: @@ -962,7 +968,7 @@ def sample(self, num_samples, cond=None): # ) # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): - print(f'Sample timestep {i:4d}', end='\r') + self.print(f'Sample timestep {i:4d}', end='\r') t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), @@ -974,7 +980,7 @@ def sample(self, num_samples, cond=None): if has_cat: log_z = self.p_sample(model_out_cat, log_z, t=t) - print() + self.print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: @@ -984,7 +990,7 @@ def sample(self, num_samples, cond=None): def sample_all(self, num_samples, cond=None, max_batch_size=2000, ddim=False): if ddim: - print('Sample using DDIM.') + self.print('Sample using DDIM.') sample_fn = self.sample_ddim else: sample_fn = self.sample diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt b/src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt deleted file mode 100644 index acc088c4..00000000 --- a/src/synthcity/plugins/core/models/tabular_ddpm/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -category-encoders==2.3.0 -dython==0.5.1 -icecream==2.1.2 -libzero==0.0.8 -numpy==1.21.4 -optuna==2.10.1 -pandas==1.3.4 -pyarrow==6.0.0 -rtdl==0.0.9 -scikit-learn==1.0.2 -scipy==1.7.2 -skorch==0.11.0 -tomli-w==0.4.0 -tomli==1.2.2 -tqdm==4.62.3 diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index f2f07f80..e2fde2d1 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -68,6 +68,7 
@@ def __init__( gaussian_loss_type = 'mse', scheduler = 'cosine', device: Any = DEVICE, + verbose: int = 0, log_interval: int = 100, print_interval: int = 500, # model params @@ -99,8 +100,8 @@ def __init__( self.is_classification = is_classification rtdl_params = dict( - d_layers = [self.dim_hidden] * self.num_layers, - dropout = self.dropout + d_layers = [dim_hidden] * num_layers, + dropout = dropout ) self.model = TabDDPM( n_iter=n_iter, @@ -111,6 +112,7 @@ def __init__( gaussian_loss_type=gaussian_loss_type, scheduler=scheduler, device=device, + verbose=verbose, log_interval=log_interval, print_interval=print_interval, model_type=model_type, @@ -161,14 +163,14 @@ def _fit(self, data: DataLoader, cond: pd.Series = None, **kwargs) -> "TabDDPMPl assert cond is None _, cond = data.unpack() self._labels, self._cond_dist = np.unique(cond, return_counts=True) - self._cond_dist /= self._cond_dist.sum() - - if cond is not None: - cond = pd.Series(cond, index=data.index) + self._cond_dist = self._cond_dist / self._cond_dist.sum() # NOTE: should we include the target column in `data`? 
data = data.dataframe() + if cond is not None: + cond = pd.Series(cond, index=data.index) + # self.encoder = TabularEncoder().fit(X) self.model.fit(data, cond, **kwargs) diff --git a/src/temp.py b/src/temp.py new file mode 100644 index 00000000..73e3390c --- /dev/null +++ b/src/temp.py @@ -0,0 +1,15 @@ +import numpy as np +import pandas as pd + +from my_utils.debug import loadDebugger +from synthcity.plugins import Plugins +from sklearn.datasets import load_iris +from synthcity.plugins.core.dataloader import GenericDataLoader + +# loadDebugger() +X, y = load_iris(as_frame = True, return_X_y = True) +X = GenericDataLoader(X.assign(target = y), target_column="target") +plugin = Plugins().get("ddpm", n_iter=3, is_classification=True, verbose=1) +plugin.fit(X) +X_syn = plugin.generate(50) +print(X_syn) diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index a398c000..1d734753 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -16,6 +16,7 @@ plugin_args = dict( n_iter=100, is_classification=True, + verbose=1, # rtdl_params=dict( # d_layers=[256, 256], # dropout=0.0 From fc9cee0403fc76f779b704585191d6f147efb2f3 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 19:43:08 +0100 Subject: [PATCH 07/95] update TensorDataLoader and training loop --- .../core/models/tabular_ddpm/__init__.py | 71 ++++++++++--------- .../plugins/core/models/tabular_ddpm/utils.py | 16 ++--- 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 95d31db6..ea288ec8 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -24,7 +24,7 @@ class TabDDPM(nn.Module): @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, - n_iter = 10000, + n_iter = 100, lr = 
0.002, weight_decay = 1e-4, batch_size = 1024, @@ -48,8 +48,8 @@ def __init__( self.__dict__.update(locals()) del self.self - def _anneal_lr(self, step): - frac_done = step / self.steps + def _anneal_lr(self, epoch): + frac_done = epoch / self.n_iter lr = self.lr * (1 - frac_done) for param_group in self.optimizer.param_groups: param_group["lr"] = lr @@ -116,40 +116,41 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): print("Starting training") print(self) - for step, (x, y) in enumerate(self.dataloader): - curr_loss_multi = 0.0 - curr_loss_gauss = 0.0 - curr_count = 0 + steps = 0 + curr_loss_multi = 0.0 + curr_loss_gauss = 0.0 + curr_count = 0 + + for epoch in range(self.n_iter): self.diffusion.train() - self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, dict(y=y)) - loss = loss_multi + loss_gauss - loss.backward() - self.optimizer.step() - - self._anneal_lr(step) - - curr_count += len(x) - curr_loss_multi += loss_multi.item() * len(x) - curr_loss_gauss += loss_gauss.item() * len(x) - - if (step + 1) % self.log_interval == 0: - mloss = np.around(curr_loss_multi / curr_count, 4) - gloss = np.around(curr_loss_gauss / curr_count, 4) - if self.verbose and (step + 1) % self.print_interval == 0: - print(f'Step {(step + 1)}/{self.n_iter} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') - self.loss_history.loc[len(self.loss_history)] = [ - step + 1, mloss, gloss, mloss + gloss] - curr_count = 0 - curr_loss_gauss = 0.0 - curr_loss_multi = 0.0 - - self._update_ema(self.ema_model.parameters(), self.model.parameters()) - - if step == self.n_iter - 1: - break - + for x, y in self.dataloader: + self.optimizer.zero_grad() + loss_multi, loss_gauss = self.diffusion.mixed_loss(x, y) + loss = loss_multi + loss_gauss + loss.backward() + self.optimizer.step() + + self._anneal_lr(epoch + 1) + + curr_count += len(x) + curr_loss_multi += loss_multi.item() * len(x) + curr_loss_gauss += loss_gauss.item() * len(x) + + steps += 1 + if steps % 
self.log_interval == 0: + mloss = np.around(curr_loss_multi / curr_count, 4) + gloss = np.around(curr_loss_gauss / curr_count, 4) + if self.verbose and steps % self.print_interval == 0: + print(f'Step {steps}: MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') + self.loss_history.loc[len(self.loss_history)] = \ + [steps, mloss, gloss, mloss + gloss] + curr_count = 0 + curr_loss_gauss = 0.0 + curr_loss_multi = 0.0 + + self._update_ema(self.ema_model.parameters(), self.model.parameters()) + return self def generate(self, count: int, cond=None): diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index a0021a68..ff92f275 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -198,15 +198,13 @@ def __init__(self, *tensors, batch_size=32, shuffle=False): self.shuffle = shuffle def __iter__(self): - i = 0 idx = np.arange(self.dataset_len) if self.shuffle: np.random.shuffle(idx) - while True: - j = i + self.batch_size - s = slice(i, j) - if j > self.dataset_len: - s = list(range(i, self.dataset_len)) + list(range(0, j - self.dataset_len)) - if self.shuffle: - np.random.shuffle(idx) - yield tuple(t[idx[s]] for t in self.tensors) + for i in range(0, self.dataset_len, self.batch_size): + s = idx[i:i+self.batch_size] + yield tuple(t[s] for t in self.tensors) + + def __len__(self): + return len(range(0, self.dataset_len, self.batch_size)) + \ No newline at end of file From d8b57addee3d08eff70f1eb98256664547fb19ee Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 20:20:01 +0100 Subject: [PATCH 08/95] clear bugs --- .../plugins/core/models/tabular_ddpm/__init__.py | 12 +++++++++--- .../gaussian_multinomial_diffsuion.py | 7 ++++--- src/synthcity/plugins/generic/plugin_ddpm.py | 2 +- src/temp.py | 15 --------------- tests/plugins/generic/test_ddpm.py | 5 +++++ 5 files changed, 19 insertions(+), 22 
deletions(-) delete mode 100644 src/temp.py diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index ea288ec8..daa8c8eb 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -87,8 +87,9 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): dim_t = self.dim_label_emb ) - tensors = [torch.tensor(t.values, dtype=torch.float32, device=self.device) - for t in ([X] if cond is None else [X, cond])] + tensors = [torch.tensor(X.values, dtype=torch.float32, device=self.device)] + if cond is not None: + tensors.append(torch.tensor(cond.values, dtype=torch.long, device=self.device)) self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( @@ -149,12 +150,17 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): curr_loss_gauss = 0.0 curr_loss_multi = 0.0 - self._update_ema(self.ema_model.parameters(), self.model.parameters()) + self._update_ema(self.ema_model.parameters(), self.diffusion.parameters()) return self def generate(self, count: int, cond=None): self.diffusion.eval() + if cond is not None: + cond = torch.tensor(cond, dtype=torch.long, device=self.device) sample = self.diffusion.sample_all(count, cond).detach().cpu().numpy() sample = sample[:, self._col_perm] + if self.verbose: + print("Generated sample") + print(sample) return sample diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 580de6fe..ba7b1e55 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -92,7 +92,7 @@ def __init__( ' This is expensive both in terms of memory and computation.') self.num_numerics = 
num_numerical_features - self.num_classes = num_categorical_features + self.num_classes = np.asarray(num_categorical_features) self.num_classes_expanded = torch.from_numpy( np.concatenate([np.repeat(k, k) for k in num_categorical_features], dtype=np.float32)).to(device) @@ -102,7 +102,7 @@ def __init__( offsets = np.cumsum(self.num_classes) for i in range(1, len(offsets)): self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i])) - self.offsets = torch.from_numpy(np.append([0], offsets)).to(device) + self.offsets = torch.from_numpy(np.append([0], offsets)).to(device).long() if model_params is None: model_params = dict( @@ -426,7 +426,8 @@ def q_posterior(self, log_x_start, log_x_t, t): num_axes = (1,) * (len(log_x_start.size()) - 1) t_broadcast = t.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like(log_x_start) - log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0.to(torch.float32)) + log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, + log_EV_qxtmin_x0.to(torch.float32)) # unnormed_logprobs = log_EV_qxtmin_x0 + # log q_pred_one_timestep(x_t, t) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index e2fde2d1..3a3da116 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -174,7 +174,7 @@ def _fit(self, data: DataLoader, cond: pd.Series = None, **kwargs) -> "TabDDPMPl # self.encoder = TabularEncoder().fit(X) self.model.fit(data, cond, **kwargs) - + def _generate(self, count: int, syn_schema: Schema, cond=None, **kwargs: Any) -> DataLoader: if self.is_classification and cond is None: # randomly generate labels following the distribution of the training data diff --git a/src/temp.py b/src/temp.py deleted file mode 100644 index 73e3390c..00000000 --- a/src/temp.py +++ /dev/null @@ -1,15 +0,0 @@ -import numpy as np -import pandas as pd - -from my_utils.debug import loadDebugger -from 
synthcity.plugins import Plugins -from sklearn.datasets import load_iris -from synthcity.plugins.core.dataloader import GenericDataLoader - -# loadDebugger() -X, y = load_iris(as_frame = True, return_X_y = True) -X = GenericDataLoader(X.assign(target = y), target_column="target") -plugin = Plugins().get("ddpm", n_iter=3, is_classification=True, verbose=1) -plugin.fit(X) -X_syn = plugin.generate(50) -print(X_syn) diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index 1d734753..e11d8af9 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -16,7 +16,12 @@ plugin_args = dict( n_iter=100, is_classification=True, + n_iter=1000, + batch_size=200, + num_timesteps=500, verbose=1, + log_interval=10, + print_interval=50 # rtdl_params=dict( # d_layers=[256, 256], # dropout=0.0 From 92dcc328456e9c3230cd565497d14dc72c7e6fb8 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 20:35:58 +0100 Subject: [PATCH 09/95] debug for regression tasks --- .../plugins/core/models/tabular_ddpm/__init__.py | 12 ++++++------ tests/plugins/generic/test_ddpm.py | 9 ++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index daa8c8eb..495302fa 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -87,9 +87,9 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): dim_t = self.dim_label_emb ) - tensors = [torch.tensor(X.values, dtype=torch.float32, device=self.device)] - if cond is not None: - tensors.append(torch.tensor(cond.values, dtype=torch.long, device=self.device)) + tensors = [torch.tensor(X.values, dtype=torch.float32, device=self.device), + np.repeat(None, len(X)) if cond is None else + torch.tensor(cond.values, dtype=torch.long, device=self.device)] 
self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( @@ -113,9 +113,9 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) - if self.verbose: - print("Starting training") - print(self) + # if self.verbose: + # print("Starting training") + # print(self) steps = 0 curr_loss_multi = 0.0 diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index e11d8af9..7f56077a 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -14,14 +14,13 @@ plugin_name = "ddpm" plugin_args = dict( - n_iter=100, - is_classification=True, - n_iter=1000, + n_iter=1000, + # is_classification=True, batch_size=200, num_timesteps=500, verbose=1, log_interval=10, - print_interval=50 + print_interval=100 # rtdl_params=dict( # d_layers=[256, 256], # dropout=0.0 @@ -129,7 +128,7 @@ def test_eval_performance_ddpm(compress_dataset: bool) -> None: X = GenericDataLoader(Xraw) for _ in range(2): - test_plugin = plugin(n_iter=5000, compress_dataset=compress_dataset) + test_plugin = plugin(**plugin_args, compress_dataset=compress_dataset) evaluator = PerformanceEvaluatorXGB() test_plugin.fit(X) From 0b9d0e3d6dc05fb47c3fb01439b55ec65db18767 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 20:49:00 +0100 Subject: [PATCH 10/95] debug for regression tasks; ALL TESTS PASSED --- .../core/models/tabular_ddpm/__init__.py | 31 +++++++++++-------- tests/plugins/generic/test_ddpm.py | 9 +++--- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index daa8c8eb..6c9947b3 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -72,13 +72,18 @@ def fit(self, X: 
pd.DataFrame, cond=None, **kwargs: Any): n_labels = 0 cat_cols = discrete_columns(X, return_counts=True) - ini_cols = X.columns - cat_cols, cat_counts = zip(*cat_cols) - # reorder the columns so that the categorical ones go to the end - X = X[np.hstack([X.columns[~X.keys().isin(cat_cols)], cat_cols])] - cur_cols = X.columns - # find the permutation from the reordered columns to the original ones - self._col_perm = np.argsort(cur_cols)[np.argsort(np.argsort(ini_cols))] + + if cat_cols: + ini_cols = X.columns + cat_cols, cat_counts = zip(*cat_cols) + # reorder the columns so that the categorical ones go to the end + X = X[np.hstack([X.columns[~X.keys().isin(cat_cols)], cat_cols])] + cur_cols = X.columns + # find the permutation from the reordered columns to the original ones + self._col_perm = np.argsort(cur_cols)[np.argsort(np.argsort(ini_cols))] + else: + cat_counts = [0] + self._col_perm = np.arange(X.shape[1]) model_params = dict( num_classes=n_labels, @@ -87,9 +92,9 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): dim_t = self.dim_label_emb ) - tensors = [torch.tensor(X.values, dtype=torch.float32, device=self.device)] - if cond is not None: - tensors.append(torch.tensor(cond.values, dtype=torch.long, device=self.device)) + tensors = [torch.tensor(X.values, dtype=torch.float32, device=self.device), + np.repeat(None, len(X)) if cond is None else + torch.tensor(cond.values, dtype=torch.long, device=self.device)] self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( @@ -113,9 +118,9 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) - if self.verbose: - print("Starting training") - print(self) + # if self.verbose: + # print("Starting training") + # print(self) steps = 0 curr_loss_multi = 0.0 diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index e11d8af9..7f56077a 
100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -14,14 +14,13 @@ plugin_name = "ddpm" plugin_args = dict( - n_iter=100, - is_classification=True, - n_iter=1000, + n_iter=1000, + # is_classification=True, batch_size=200, num_timesteps=500, verbose=1, log_interval=10, - print_interval=50 + print_interval=100 # rtdl_params=dict( # d_layers=[256, 256], # dropout=0.0 @@ -129,7 +128,7 @@ def test_eval_performance_ddpm(compress_dataset: bool) -> None: X = GenericDataLoader(Xraw) for _ in range(2): - test_plugin = plugin(n_iter=5000, compress_dataset=compress_dataset) + test_plugin = plugin(**plugin_args, compress_dataset=compress_dataset) evaluator = PerformanceEvaluatorXGB() test_plugin.fit(X) From bb9822954f960652e277db117d734bf3b97cec22 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 7 Mar 2023 21:09:28 +0100 Subject: [PATCH 11/95] remove the official repo of TabDDPM --- third-party/tab-ddpm | 1 - 1 file changed, 1 deletion(-) delete mode 160000 third-party/tab-ddpm diff --git a/third-party/tab-ddpm b/third-party/tab-ddpm deleted file mode 160000 index 41f2415a..00000000 --- a/third-party/tab-ddpm +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 41f2415a378f1e8e8f4f5c3b8736521c0d47cf22 From b4486a48caf102fda67495816b6b314dd03ebe62 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Wed, 8 Mar 2023 10:57:56 +0100 Subject: [PATCH 12/95] passed all pre-commit checks --- docs/tutorials | 2 +- .../core/models/tabular_ddpm/__init__.py | 90 +-- .../gaussian_multinomial_diffsuion.py | 527 +++++++++--------- .../core/models/tabular_ddpm/modules.py | 80 +-- .../plugins/core/models/tabular_ddpm/utils.py | 62 ++- .../plugins/core/models/tabular_encoder.py | 2 +- src/synthcity/plugins/generic/plugin_ddpm.py | 68 +-- src/synthcity/utils/dataframe.py | 16 +- 8 files changed, 460 insertions(+), 387 deletions(-) diff --git a/docs/tutorials b/docs/tutorials index c6fce2d2..27afa3de 120000 --- a/docs/tutorials 
+++ b/docs/tutorials @@ -1 +1 @@ -../tutorials/ \ No newline at end of file +../tutorials/ diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 6c9947b3..98ac9619 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-untyped-calls +# flake8: noqa: F401 + # stdlib from copy import deepcopy from typing import Any, Optional, Union @@ -6,37 +9,36 @@ import numpy as np import pandas as pd import torch -from torch import nn from pydantic import validate_arguments +from torch import nn # synthcity absolute +from synthcity.metrics.weighted_metrics import WeightedMetrics from synthcity.utils.constants import DEVICE from synthcity.utils.dataframe import discrete_columns -from synthcity.metrics.weighted_metrics import WeightedMetrics -from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion # noqa -from .modules import MLPDiffusion, ResNetDiffusion # noqa +# synthcity relative +from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion from .utils import TensorDataLoader class TabDDPM(nn.Module): - @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, - n_iter = 100, - lr = 0.002, - weight_decay = 1e-4, - batch_size = 1024, - num_timesteps = 1000, - gaussian_loss_type = 'mse', - scheduler = 'cosine', + n_iter: int = 1000, + lr: float = 0.002, + weight_decay: float = 1e-4, + batch_size: int = 1024, + num_timesteps: int = 1000, + gaussian_loss_type: str = "mse", + scheduler: str = "cosine", device: Any = DEVICE, verbose: int = 0, - log_interval: int = 100, - print_interval: int = 500, + log_interval: int = 10, + print_interval: int = 100, # model params - model_type = 'mlp', + model_type: str = "mlp", rtdl_params: Optional[dict] = None, # {'d_layers', 'dropout'} dim_label_emb: int = 128, # early stopping 
@@ -47,7 +49,7 @@ def __init__( super().__init__() self.__dict__.update(locals()) del self.self - + def _anneal_lr(self, epoch): frac_done = epoch / self.n_iter lr = self.lr * (1 - frac_done) @@ -65,14 +67,14 @@ def _update_ema(self, target_params, source_params, rate=0.999): for targ, src in zip(target_params, source_params): targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) - def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): + def fit(self, X: pd.DataFrame, cond: Any = None, **kwargs: Any) -> "TabDDPM": if cond is not None: n_labels = cond.nunique() else: n_labels = 0 cat_cols = discrete_columns(X, return_counts=True) - + if cat_cols: ini_cols = X.columns cat_cols, cat_counts = zip(*cat_cols) @@ -89,47 +91,51 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): num_classes=n_labels, is_y_cond=cond is not None, rtdl_params=self.rtdl_params, - dim_t = self.dim_label_emb + dim_t=self.dim_label_emb, ) - - tensors = [torch.tensor(X.values, dtype=torch.float32, device=self.device), - np.repeat(None, len(X)) if cond is None else - torch.tensor(cond.values, dtype=torch.long, device=self.device)] + + tensors = [ + torch.tensor(X.values, dtype=torch.float32, device=self.device), + np.repeat(None, len(X)) + if cond is None + else torch.tensor(cond.values, dtype=torch.long, device=self.device), + ] self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( model_type=self.model_type, model_params=model_params, num_categorical_features=cat_counts, - num_numerical_features=X.shape[1]-len(cat_cols), + num_numerical_features=X.shape[1] - len(cat_cols), gaussian_loss_type=self.gaussian_loss_type, num_timesteps=self.num_timesteps, scheduler=self.scheduler, device=self.device, verbose=self.verbose, ).to(self.device) - + self.ema_model = deepcopy(self.diffusion.denoise_fn) for param in self.ema_model.parameters(): param.detach_() self.optimizer = torch.optim.AdamW( - self.diffusion.parameters(), 
lr=self.lr, weight_decay=self.weight_decay) - - self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss']) - + self.diffusion.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) + + self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) + # if self.verbose: # print("Starting training") # print(self) - + steps = 0 curr_loss_multi = 0.0 curr_loss_gauss = 0.0 curr_count = 0 - + for epoch in range(self.n_iter): self.diffusion.train() - + for x, y in self.dataloader: self.optimizer.zero_grad() loss_multi, loss_gauss = self.diffusion.mixed_loss(x, y) @@ -148,18 +154,26 @@ def fit(self, X: pd.DataFrame, cond=None, **kwargs: Any): mloss = np.around(curr_loss_multi / curr_count, 4) gloss = np.around(curr_loss_gauss / curr_count, 4) if self.verbose and steps % self.print_interval == 0: - print(f'Step {steps}: MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}') - self.loss_history.loc[len(self.loss_history)] = \ - [steps, mloss, gloss, mloss + gloss] + print( + f"Step {steps}: MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}" + ) + self.loss_history.loc[len(self.loss_history)] = [ + steps, + mloss, + gloss, + mloss + gloss, + ] curr_count = 0 curr_loss_gauss = 0.0 curr_loss_multi = 0.0 - self._update_ema(self.ema_model.parameters(), self.diffusion.parameters()) - + self._update_ema( + self.ema_model.parameters(), self.diffusion.parameters() + ) + return self - def generate(self, count: int, cond=None): + def generate(self, count: int, cond: Any = None) -> np.ndarray: self.diffusion.eval() if cond is not None: cond = torch.tensor(cond, dtype=torch.long, device=self.device) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index ba7b1e55..7a2b358d 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ 
b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -1,20 +1,24 @@ """ -Based on https://github.com/openai/guided-diffusion/blob/main/guided_diffusion -and https://github.com/ehoogeboom/multinomial_diffusion +Based on +- https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +- https://github.com/ehoogeboom/multinomial_diffusion +- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ +# mypy: disable-error-code=no-untyped-def +# flake8: noqa: F405 -import torch.nn.functional as F -import torch +# stdlib import math -import pandas as pd +# third party import numpy as np -from .utils import * +import torch +import torch.nn.functional as F + +# synthcity relative from .modules import MLPDiffusion, ResNetDiffusion +from .utils import * # noqa: F403 -""" -Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 -""" def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): """ @@ -63,41 +67,45 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( - self, - num_numerical_features, - num_categorical_features, - model_type='mlp', - model_params=None, - num_timesteps=1000, - gaussian_loss_type='mse', - gaussian_parametrization='eps', - multinomial_loss_type='vb_stochastic', - parametrization='x0', - scheduler='cosine', - device=torch.device('cpu'), - verbose=0 - ): + self, + num_numerical_features, + num_categorical_features, + model_type="mlp", + model_params=None, + num_timesteps=1000, + gaussian_loss_type="mse", + gaussian_parametrization="eps", + multinomial_loss_type="vb_stochastic", + parametrization="x0", + scheduler="cosine", + device=torch.device("cpu"), + verbose=0, + ): 
super(GaussianMultinomialDiffusion, self).__init__() - assert multinomial_loss_type in ('vb_stochastic', 'vb_all') - assert parametrization in ('x0', 'direct') - + assert multinomial_loss_type in ("vb_stochastic", "vb_all") + assert parametrization in ("x0", "direct") + if verbose: self.print = print else: self.print = lambda *args, **kwargs: None - if multinomial_loss_type == 'vb_all': - self.print('Computing the loss using the bound on _all_ timesteps.' - ' This is expensive both in terms of memory and computation.') + if multinomial_loss_type == "vb_all": + self.print( + "Computing the loss using the bound on _all_ timesteps." + " This is expensive both in terms of memory and computation." + ) self.num_numerics = num_numerical_features self.num_classes = np.asarray(num_categorical_features) self.num_classes_expanded = torch.from_numpy( - np.concatenate([np.repeat(k, k) for k in num_categorical_features], - dtype=np.float32)).to(device) + np.concatenate( + [np.repeat(k, k) for k in num_categorical_features], dtype=np.float32 + ) + ).to(device) self.dim_input = self.num_numerics + sum(self.num_classes) - + self.slices_for_classes = [np.arange(self.num_classes[0])] offsets = np.cumsum(self.num_classes) for i in range(1, len(offsets)): @@ -106,27 +114,21 @@ def __init__( if model_params is None: model_params = dict( - d_in = self.dim_input, - num_classes = 0, - is_y_cond = False, - rtdl_params = None + d_in=self.dim_input, num_classes=0, is_y_cond=False, rtdl_params=None ) else: - model_params['d_in'] = self.dim_input - - if model_params['rtdl_params'] is None: - model_params['rtdl_params'] = dict( - d_layers = [256, 256, 256], - dropout = 0.0 - ) - - if model_type == 'mlp': + model_params["d_in"] = self.dim_input + + if model_params["rtdl_params"] is None: + model_params["rtdl_params"] = dict(d_layers=[256, 256, 256], dropout=0.0) + + if model_type == "mlp": self.denoise_fn = MLPDiffusion(**model_params) - elif model_type == 'resnet': + elif model_type == "resnet": 
self.denoise_fn = ResNetDiffusion(**model_params) else: raise "Unknown diffusion model type!" - + self.gaussian_loss_type = gaussian_loss_type self.gaussian_parametrization = gaussian_parametrization self.multinomial_loss_type = multinomial_loss_type @@ -134,9 +136,9 @@ def __init__( self.parametrization = parametrization self.scheduler = scheduler - alphas = 1. - get_named_beta_schedule(scheduler, num_timesteps) - alphas = torch.tensor(alphas.astype('float64')) - betas = 1. - alphas + alphas = 1.0 - get_named_beta_schedule(scheduler, num_timesteps) + alphas = torch.tensor(alphas.astype("float64")) + betas = 1.0 - alphas log_alpha = np.log(alphas) log_cumprod_alpha = np.cumsum(log_alpha) @@ -157,60 +159,86 @@ def __init__( self.posterior_variance = ( betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) - self.posterior_log_variance_clipped = torch.from_numpy( - np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) - ).float().to(device) + self.posterior_log_variance_clipped = ( + torch.from_numpy( + np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + ) + .float() + .to(device) + ) self.posterior_mean_coef1 = ( - betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) - ).float().to(device) + (betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + .float() + .to(device) + ) self.posterior_mean_coef2 = ( - (1.0 - alphas_cumprod_prev) - * np.sqrt(alphas.numpy()) - / (1.0 - alphas_cumprod) - ).float().to(device) + ( + (1.0 - alphas_cumprod_prev) + * np.sqrt(alphas.numpy()) + / (1.0 - alphas_cumprod) + ) + .float() + .to(device) + ) - assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5 - assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 - assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5 + assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.0e-5 + assert ( + log_add_exp(log_cumprod_alpha, 
log_1_min_cumprod_alpha).abs().sum().item() + < 1e-5 + ) + assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.0e-5 # Convert to float32 and register buffers. - self.register_buffer('alphas', alphas.float().to(device)) - self.register_buffer('log_alpha', log_alpha.float().to(device)) - self.register_buffer('log_1_min_alpha', log_1_min_alpha.float().to(device)) - self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float().to(device)) - self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float().to(device)) - self.register_buffer('alphas_cumprod', alphas_cumprod.float().to(device)) - self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float().to(device)) - self.register_buffer('alphas_cumprod_next', alphas_cumprod_next.float().to(device)) - self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod.float().to(device)) - self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod.float().to(device)) - self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod.float().to(device)) - self.register_buffer('sqrt_recipm1_alphas_cumprod', sqrt_recipm1_alphas_cumprod.float().to(device)) - - self.register_buffer('Lt_history', torch.zeros(num_timesteps)) - self.register_buffer('Lt_count', torch.zeros(num_timesteps)) - + self.register_buffer("alphas", alphas.float().to(device)) + self.register_buffer("log_alpha", log_alpha.float().to(device)) + self.register_buffer("log_1_min_alpha", log_1_min_alpha.float().to(device)) + self.register_buffer( + "log_1_min_cumprod_alpha", log_1_min_cumprod_alpha.float().to(device) + ) + self.register_buffer("log_cumprod_alpha", log_cumprod_alpha.float().to(device)) + self.register_buffer("alphas_cumprod", alphas_cumprod.float().to(device)) + self.register_buffer( + "alphas_cumprod_prev", alphas_cumprod_prev.float().to(device) + ) + self.register_buffer( + "alphas_cumprod_next", alphas_cumprod_next.float().to(device) + ) + 
self.register_buffer( + "sqrt_alphas_cumprod", sqrt_alphas_cumprod.float().to(device) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + sqrt_one_minus_alphas_cumprod.float().to(device), + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod.float().to(device) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + sqrt_recipm1_alphas_cumprod.float().to(device), + ) + + self.register_buffer("Lt_history", torch.zeros(num_timesteps)) + self.register_buffer("Lt_count", torch.zeros(num_timesteps)) + # Gaussian part def gaussian_q_mean_variance(self, x_start, t): - mean = ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - ) + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract( - self.log_1_min_cumprod_alpha, t, x_start.shape - ) + log_variance = extract(self.log_1_min_cumprod_alpha, t, x_start.shape) return mean, variance, log_variance - + def gaussian_q_sample(self, x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) assert noise.shape == x_start.shape return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) - + def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): assert x_start.shape == x_t.shape posterior_mean = ( @@ -230,7 +258,13 @@ def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): return posterior_mean, posterior_variance, posterior_log_variance_clipped def gaussian_p_mean_variance( - self, model_output, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None + self, + model_output, + x, + t, + clip_denoised=False, + denoised_fn=None, + model_kwargs=None, ): if model_kwargs is None: model_kwargs = {} @@ -238,27 +272,33 @@ def gaussian_p_mean_variance( B, C = x.shape[:2] 
assert t.shape == (B,) - model_variance = torch.cat([self.posterior_variance[1].unsqueeze(0).to(x.device), (1. - self.alphas)[1:]], dim=0) + model_variance = torch.cat( + [ + self.posterior_variance[1].unsqueeze(0).to(x.device), + (1.0 - self.alphas)[1:], + ], + dim=0, + ) # model_variance = self.posterior_variance.to(x.device) model_log_variance = torch.log(model_variance) model_variance = extract(model_variance, t, x.shape) model_log_variance = extract(model_log_variance, t, x.shape) - if self.gaussian_parametrization == 'eps': + if self.gaussian_parametrization == "eps": pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) - elif self.gaussian_parametrization == 'x0': + elif self.gaussian_parametrization == "x0": pred_xstart = model_output else: raise NotImplementedError - + model_mean, _, _ = self.gaussian_q_posterior_mean_variance( x_start=pred_xstart, x_t=x, t=t ) assert ( model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - ), f'{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}' + ), f"{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}" return { "mean": model_mean, @@ -266,13 +306,15 @@ def gaussian_p_mean_variance( "log_variance": model_log_variance, "pred_xstart": pred_xstart, } - + def _vb_terms_bpd( self, model_output, x_start, x_t, t, clip_denoised=False, model_kwargs=None ): - true_mean, _, true_log_variance_clipped = self.gaussian_q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - ) + ( + true_mean, + _, + true_log_variance_clipped, + ) = self.gaussian_q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) out = self.gaussian_p_mean_variance( model_output, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) @@ -290,8 +332,13 @@ def _vb_terms_bpd( # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = torch.where((t == 0), decoder_nll, kl) - return {"output": 
output, "pred_xstart": out["pred_xstart"], "out_mean": out["mean"], "true_mean": true_mean} - + return { + "output": output, + "pred_xstart": out["pred_xstart"], + "out_mean": out["mean"], + "true_mean": true_mean, + } + def _prior_gaussian(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in @@ -309,15 +356,15 @@ def _prior_gaussian(self, x_start): mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) - + def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None): if model_kwargs is None: model_kwargs = {} terms = {} - if self.gaussian_loss_type == 'mse': + if self.gaussian_loss_type == "mse": terms["loss"] = mean_flat((noise - model_out) ** 2) - elif self.gaussian_loss_type == 'kl': + elif self.gaussian_loss_type == "kl": terms["loss"] = self._vb_terms_bpd( model_output=model_out, x_start=x_start, @@ -327,20 +374,18 @@ def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None): model_kwargs=model_kwargs, )["output"] + return terms["loss"] - return terms['loss'] - def _predict_xstart_from_eps(self, x_t, t, eps=1e-8): assert x_t.shape == eps.shape return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) - + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - pred_xstart + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def gaussian_p_sample( @@ -365,7 +410,9 @@ def gaussian_p_sample( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 - sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise + sample = ( + out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise + ) return {"sample": sample, "pred_xstart": out["pred_xstart"]} # 
Multinomial part @@ -381,18 +428,20 @@ def q_pred_one_timestep(self, log_x_t, t): # alpha_t * E[xt] + (1 - alpha_t) 1 / K log_probs = log_add_exp( log_x_t + log_alpha_t, - log_1_min_alpha_t - torch.log(self.num_classes_expanded) + log_1_min_alpha_t - torch.log(self.num_classes_expanded), ) return log_probs def q_pred(self, log_x_start, t): log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) - log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape) + log_1_min_cumprod_alpha = extract( + self.log_1_min_cumprod_alpha, t, log_x_start.shape + ) log_probs = log_add_exp( log_x_start + log_cumprod_alpha_t, - log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded) + log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded), ) return log_probs @@ -402,7 +451,7 @@ def predict_start(self, model_out, log_x_t): # model_out = self._denoise_fn(x_t, t.to(x_t.device), **out_dict) assert model_out.size(0) == log_x_t.size(0) - assert model_out.size(1) == self.num_classes.sum(), f'{model_out.size()}' + assert model_out.size(1) == self.num_classes.sum(), f"{model_out.size()}" log_pred = torch.empty_like(model_out) for ix in self.slices_for_classes: @@ -425,9 +474,12 @@ def q_posterior(self, log_x_start, log_x_t, t): log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) num_axes = (1,) * (len(log_x_start.size()) - 1) - t_broadcast = t.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like(log_x_start) - log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, - log_EV_qxtmin_x0.to(torch.float32)) + t_broadcast = t.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like( + log_x_start + ) + log_EV_qxtmin_x0 = torch.where( + t_broadcast == 0, log_x_start, log_EV_qxtmin_x0.to(torch.float32) + ) # unnormed_logprobs = log_EV_qxtmin_x0 + # log q_pred_one_timestep(x_t, t) @@ -435,18 +487,19 @@ def q_posterior(self, log_x_start, log_x_t, t): # Not very easy to see why this is true. 
But it is :) unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) - log_EV_xtmin_given_xt_given_xstart = \ - unnormed_logprobs \ - - sliced_logsumexp(unnormed_logprobs, self.offsets) + log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - sliced_logsumexp( + unnormed_logprobs, self.offsets + ) return log_EV_xtmin_given_xt_given_xstart def p_pred(self, model_out, log_x, t): - if self.parametrization == 'x0': + if self.parametrization == "x0": log_x_recon = self.predict_start(model_out, log_x) log_model_pred = self.q_posterior( - log_x_start=log_x_recon, log_x_t=log_x, t=t) - elif self.parametrization == 'direct': + log_x_start=log_x_recon, log_x_t=log_x, t=t + ) + elif self.parametrization == "direct": log_model_pred = self.predict_start(model_out, log_x) else: raise ValueError @@ -467,28 +520,32 @@ def p_sample_loop(self, shape): img = torch.randn(shape, device=device) for i in reversed(range(1, self.num_timesteps)): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) + img = self.p_sample( + img, torch.full((b,), i, device=device, dtype=torch.long) + ) return img @torch.no_grad() - def _sample(self, image_size, batch_size = 16): + def _sample(self, image_size, batch_size=16): return self.p_sample_loop((batch_size, 3, image_size, image_size)) - @torch.no_grad() - def interpolate(self, x1, x2, t = None, lam = 0.5): - b, *_, device = *x1.shape, x1.device - t = default(t, self.num_timesteps - 1) + # @torch.no_grad() + # def interpolate(self, x1, x2, t=None, lam=0.5): + # b, *_, device = *x1.shape, x1.device + # t = default(t, self.num_timesteps - 1) - assert x1.shape == x2.shape + # assert x1.shape == x2.shape - t_batched = torch.stack([torch.tensor(t, device=device)] * b) - xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + # t_batched = torch.stack([torch.tensor(t, device=device)] * b) + # xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) - img = (1 - lam) * xt1 + lam * xt2 - for 
i in reversed(range(0, t)): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) + # img = (1 - lam) * xt1 + lam * xt2 + # for i in reversed(range(0, t)): + # img = self.p_sample( + # img, torch.full((b,), i, device=device, dtype=torch.long) + # ) - return img + # return img def log_sample_categorical(self, logits): full_sample = [] @@ -519,7 +576,8 @@ def nll(self, log_x_start): kl = self.compute_Lt( log_x_start=log_x_start, log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array), - t=t_array) + t=t_array, + ) loss += kl @@ -533,14 +591,15 @@ def kl_prior(self, log_x_start): ones = torch.ones(b, device=device).long() log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) - log_half_prob = -torch.log(self.num_classes_expanded * torch.ones_like(log_qxT_prob)) + log_half_prob = -torch.log( + self.num_classes_expanded * torch.ones_like(log_qxT_prob) + ) kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) return sum_except_batch(kl_prior) def compute_Lt(self, model_out, log_x_start, log_x_t, t, detach_mean=False): - log_true_prob = self.q_posterior( - log_x_start=log_x_start, log_x_t=log_x_t, t=t) + log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t) log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t) if detach_mean: @@ -553,14 +612,14 @@ def compute_Lt(self, model_out, log_x_start, log_x_t, t, detach_mean=False): decoder_nll = sum_except_batch(decoder_nll) mask = (t == torch.zeros_like(t)).float() - loss = mask * decoder_nll + (1. 
- mask) * kl + loss = mask * decoder_nll + (1.0 - mask) * kl return loss - def sample_time(self, b, device, method='uniform'): - if method == 'importance': + def sample_time(self, b, device, method="uniform"): + if method == "importance": if not (self.Lt_count > 10).all(): - return self.sample_time(b, device, method='uniform') + return self.sample_time(b, device, method="uniform") Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001 Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1. @@ -572,7 +631,7 @@ def sample_time(self, b, device, method='uniform'): return t, pt - elif method == 'uniform': + elif method == "uniform": t = torch.randint(0, self.num_timesteps, (b,), device=device).long() pt = torch.ones_like(t).float() / self.num_timesteps @@ -582,17 +641,15 @@ def sample_time(self, b, device, method='uniform'): def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt): - if self.multinomial_loss_type == 'vb_stochastic': - kl = self.compute_Lt( - model_out, log_x_start, log_x_t, t - ) + if self.multinomial_loss_type == "vb_stochastic": + kl = self.compute_Lt(model_out, log_x_start, log_x_t, t) kl_prior = self.kl_prior(log_x_start) # Upweigh loss term of the kl vb_loss = kl / pt + kl_prior return vb_loss - elif self.multinomial_loss_type == 'vb_all': + elif self.multinomial_loss_type == "vb_all": # Expensive, dont do it ;). # DEPRECATED return -self.nll(log_x_start) @@ -602,7 +659,7 @@ def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt): #! Not used def log_prob(self, x): b, device = x.size(0), x.device - + if self.training: #! 
not enough arguments return self._multinomial_loss(x) @@ -610,10 +667,11 @@ def log_prob(self, x): else: log_x_start = index_to_log_onehot(x, self.num_classes) - t, pt = self.sample_time(b, device, 'importance') + t, pt = self.sample_time(b, device, "importance") kl = self.compute_Lt( - log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t) + log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t + ) kl_prior = self.kl_prior(log_x_start) @@ -621,15 +679,15 @@ def log_prob(self, x): loss = kl / pt + kl_prior return -loss - + def mixed_loss(self, x, cond=None): b = x.shape[0] device = x.device - t, pt = self.sample_time(b, device, 'uniform') + t, pt = self.sample_time(b, device, "uniform") + + x_num = x[:, : self.num_numerics] + x_cat = x[:, self.num_numerics :] - x_num = x[:, :self.num_numerics] - x_cat = x[:, self.num_numerics:] - x_num_t = x_num log_x_cat_t = x_cat if x_num.shape[1] > 0: @@ -638,23 +696,21 @@ def mixed_loss(self, x, cond=None): if x_cat.shape[1] > 0: log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes) log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t) - + x_in = torch.cat([x_num_t, log_x_cat_t], dim=1) - model_out = self.denoise_fn( - x_in, - t, y=cond - ) + model_out = self.denoise_fn(x_in, t, y=cond) - model_out_num = model_out[:, :self.num_numerics] - model_out_cat = model_out[:, self.num_numerics:] + model_out_num = model_out[:, : self.num_numerics] + model_out_cat = model_out[:, self.num_numerics :] loss_multi = torch.zeros((1,)).float() loss_gauss = torch.zeros((1,)).float() if x_cat.shape[1] > 0: - loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, - t, pt) / len(self.num_classes) - + loss_multi = self._multinomial_loss( + model_out_cat, log_x_cat, log_x_cat_t, t, pt + ) / len(self.num_classes) + if x_num.shape[1] > 0: loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise) @@ -662,14 +718,14 @@ def mixed_loss(self, x, cond=None): # loss_gauss = torch.where(out_dict['y'] 
== 1, loss_gauss, 2 * loss_gauss) return loss_multi.mean(), loss_gauss.mean() - + @torch.no_grad() def mixed_elbo(self, x0, cond=None): b = x0.size(0) device = x0.device - x_num = x0[:, :self.num_numerics] - x_cat = x0[:, self.num_numerics:] + x_num = x0[:, : self.num_numerics] + x_cat = x0[:, self.num_numerics :] has_cat = x_cat.shape[1] > 0 if has_cat: log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes).to(device) @@ -692,12 +748,11 @@ def mixed_elbo(self, x0, cond=None): log_x_cat_t = x_cat model_out = self.denoise_fn( - torch.cat([x_num_t, log_x_cat_t], dim=1), - t_array, y=cond + torch.cat([x_num_t, log_x_cat_t], dim=1), t_array, y=cond ) - - model_out_num = model_out[:, :self.num_numerics] - model_out_cat = model_out[:, self.num_numerics:] + + model_out_num = model_out[:, : self.num_numerics] + model_out_cat = model_out[:, self.num_numerics :] kl = torch.tensor([0.0]) if has_cat: @@ -713,7 +768,7 @@ def mixed_elbo(self, x0, cond=None): x_start=x_num, x_t=x_num_t, t=t_array, - clip_denoised=False + clip_denoised=False, ) multinomial_loss.append(kl) @@ -751,18 +806,12 @@ def mixed_elbo(self, x0, cond=None): "mse": mse, # "mu_mse": mu_mse "out_mean": out_mean, - "true_mean": true_mean + "true_mean": true_mean, } @torch.no_grad() def gaussian_ddim_step( - self, - model_out_num, - x, - t, - clip_denoised=False, - denoised_fn=None, - eta=0.0 + self, model_out_num, x, t, clip_denoised=False, denoised_fn=None, eta=0.0 ): out = self.gaussian_p_mean_variance( model_out_num, @@ -786,7 +835,7 @@ def gaussian_ddim_step( noise = torch.randn_like(x) mean_pred = ( out["pred_xstart"] * torch.sqrt(alpha_bar_prev) - + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + + torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) @@ -794,39 +843,23 @@ def gaussian_ddim_step( sample = mean_pred + nonzero_mask * sigma * noise return sample - + @torch.no_grad() - def gaussian_ddim_sample( - self, - noise, - 
T, - cond=None, - eta=0.0 - ): + def gaussian_ddim_sample(self, noise, T, cond=None, eta=0.0): x = noise b = x.shape[0] device = x.device for t in reversed(range(T)): - self.print(f'Sample timestep {t:4d}', end='\r') + self.print(f"Sample timestep {t:4d}", end="\r") t_array = (torch.ones(b, device=device) * t).long() out_num = self.denoise_fn(x, t_array, y=cond) - x = self.gaussian_ddim_step( - out_num, - x, - t_array - ) + x = self.gaussian_ddim_step(out_num, x, t_array) self.print() return x - @torch.no_grad() def gaussian_ddim_reverse_step( - self, - model_out_num, - x, - t, - clip_denoised=False, - eta=0.0 + self, model_out_num, x, t, clip_denoised=False, eta=0.0 ): assert eta == 0.0, "Eta must be zero." out = self.gaussian_p_mean_variance( @@ -839,8 +872,7 @@ def gaussian_ddim_reverse_step( ) eps = ( - extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - out["pred_xstart"] + extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = extract(self.alphas_cumprod_next, t, x.shape) @@ -852,37 +884,20 @@ def gaussian_ddim_reverse_step( return mean_pred @torch.no_grad() - def gaussian_ddim_reverse_sample( - self, - x, - T, - cond=None - ): + def gaussian_ddim_reverse_sample(self, x, T, cond=None): b = x.shape[0] device = x.device for t in range(T): - self.print(f'Reverse timestep {t:4d}', end='\r') + self.print(f"Reverse timestep {t:4d}", end="\r") t_array = (torch.ones(b, device=device) * t).long() out_num = self.denoise_fn(x, t_array, y=cond) - x = self.gaussian_ddim_reverse_step( - out_num, - x, - t_array, - eta=0.0 - ) + x = self.gaussian_ddim_reverse_step(out_num, x, t_array, eta=0.0) self.print() return x - @torch.no_grad() - def multinomial_ddim_step( - self, - model_out_cat, - log_x_t, - t, - eta=0.0 - ): + def multinomial_ddim_step(self, model_out_cat, log_x_t, t, eta=0.0): # not ddim, essentially log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t) @@ 
-897,13 +912,15 @@ def multinomial_ddim_step( coef1 = sigma coef2 = alpha_bar_prev - sigma * alpha_bar coef3 = 1 - coef1 - coef2 - - log_ps = torch.stack([ - torch.log(coef1) + log_x_t, - torch.log(coef2) + log_x0, - torch.log(coef3) - torch.log(self.num_classes_expanded) - ], dim=2) + log_ps = torch.stack( + [ + torch.log(coef1) + log_x_t, + torch.log(coef2) + log_x0, + torch.log(coef3) - torch.log(self.num_classes_expanded), + ], + dim=2, + ) log_prob = torch.logsumexp(log_ps, dim=2) @@ -920,7 +937,9 @@ def sample_ddim(self, num_samples, cond=None): has_cat = self.num_classes[0] != 0 log_z = torch.zeros((b, 0), device=device).float() if has_cat: - uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device) + uniform_logits = torch.zeros( + (b, len(self.num_classes_expanded)), device=device + ) log_z = self.log_sample_categorical(uniform_logits) # y = torch.multinomial( @@ -930,15 +949,16 @@ def sample_ddim(self, num_samples, cond=None): # ) # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): - self.print(f'Sample timestep {i:4d}', end='\r') + self.print(f"Sample timestep {i:4d}", end="\r") t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( - torch.cat([z_norm, log_z], dim=1).float(), - t, y=cond + torch.cat([z_norm, log_z], dim=1).float(), t, y=cond + ) + model_out_num = model_out[:, : self.num_numerics] + model_out_cat = model_out[:, self.num_numerics :] + z_norm = self.gaussian_ddim_step( + model_out_num, z_norm, t, clip_denoised=False ) - model_out_num = model_out[:, :self.num_numerics] - model_out_cat = model_out[:, self.num_numerics:] - z_norm = self.gaussian_ddim_step(model_out_num, z_norm, t, clip_denoised=False) if has_cat: log_z = self.multinomial_ddim_step(model_out_cat, log_z, t) @@ -959,7 +979,9 @@ def sample(self, num_samples, cond=None): has_cat = self.num_classes[0] != 0 log_z = torch.zeros((b, 0), device=device).float() if has_cat: - uniform_logits = 
torch.zeros((b, len(self.num_classes_expanded)), device=device) + uniform_logits = torch.zeros( + (b, len(self.num_classes_expanded)), device=device + ) log_z = self.log_sample_categorical(uniform_logits) # y = torch.multinomial( @@ -969,15 +991,16 @@ def sample(self, num_samples, cond=None): # ) # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): - self.print(f'Sample timestep {i:4d}', end='\r') + self.print(f"Sample timestep {i:4d}", end="\r") t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( - torch.cat([z_norm, log_z], dim=1).float(), - t, y=cond + torch.cat([z_norm, log_z], dim=1).float(), t, y=cond ) - model_out_num = model_out[:, :self.num_numerics] - model_out_cat = model_out[:, self.num_numerics:] - z_norm = self.gaussian_p_sample(model_out_num, z_norm, t, clip_denoised=False)['sample'] + model_out_num = model_out[:, : self.num_numerics] + model_out_cat = model_out[:, self.num_numerics :] + z_norm = self.gaussian_p_sample( + model_out_num, z_norm, t, clip_denoised=False + )["sample"] if has_cat: log_z = self.p_sample(model_out_cat, log_z, t=t) @@ -988,17 +1011,17 @@ def sample(self, num_samples, cond=None): z_cat = ohe_to_categories(z_ohe, self.num_classes) sample = torch.cat([z_norm, z_cat], dim=1).cpu() return sample - + def sample_all(self, num_samples, cond=None, max_batch_size=2000, ddim=False): if ddim: - self.print('Sample using DDIM.') + self.print("Sample using DDIM.") sample_fn = self.sample_ddim else: sample_fn = self.sample bs = np.diff([*range(0, num_samples, max_batch_size), num_samples]) all_samples = [] - + for b in bs: sample = sample_fn(b, cond) if torch.any(sample.isnan()).item(): diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 44c63884..48310320 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py 
@@ -1,10 +1,14 @@ """ Code was adapted from https://github.com/Yura52/rtdl """ +# mypy: disable-error-code=no-untyped-def +# flake8: noqa: F401 +# stdlib import math from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +# third party import torch import torch.nn as nn import torch.nn.functional as F @@ -13,10 +17,12 @@ ModuleType = Union[str, Callable[..., nn.Module]] + class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) + def timestep_embedding(timesteps, dim, max_period=10000): """ Create sinusoidal timestep embeddings. @@ -29,7 +35,9 @@ def timestep_embedding(timesteps, dim, max_period=10000): """ half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -37,10 +45,11 @@ def timestep_embedding(timesteps, dim, max_period=10000): embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding + def _is_glu_activation(activation: ModuleType): return ( isinstance(activation, str) - and activation.endswith('GLU') + and activation.endswith("GLU") or activation in [ReGLU, GEGLU] ) @@ -48,6 +57,7 @@ def _is_glu_activation(activation: ModuleType): def _all_or_none(values): assert all(x is None for x in values) or all(x is not None for x in values) + def reglu(x: Tensor) -> Tensor: """The ReGLU activation function from [1]. References: @@ -67,6 +77,7 @@ def geglu(x: Tensor) -> Tensor: a, b = x.chunk(2, dim=-1) return a * F.gelu(b) + class ReGLU(nn.Module): """The ReGLU activation function from [shazeer2020glu]. 
@@ -102,13 +113,14 @@ class GEGLU(nn.Module): def forward(self, x: Tensor) -> Tensor: return geglu(x) + def _make_nn_module(module_type: ModuleType, *args) -> nn.Module: return ( ( ReGLU() - if module_type == 'ReGLU' + if module_type == "ReGLU" else GEGLU() - if module_type == 'GEGLU' + if module_type == "GEGLU" else getattr(nn, module_type)(*args) ) if isinstance(module_type, str) @@ -174,7 +186,7 @@ def __init__( if isinstance(dropouts, float): dropouts = [dropouts] * len(d_layers) assert len(d_layers) == len(dropouts) - assert activation not in ['ReGLU', 'GEGLU'] + assert activation not in ["ReGLU", "GEGLU"] self.blocks = nn.ModuleList( [ @@ -192,12 +204,12 @@ def __init__( @classmethod def make_baseline( - cls: Type['MLP'], + cls: Type["MLP"], d_in: int, d_layers: List[int], dropout: float, d_out: int, - ) -> 'MLP': + ) -> "MLP": """Create a "baseline" `MLP`. This variation of MLP was used in [gorishniy2021revisiting]. Features: @@ -224,14 +236,14 @@ def make_baseline( assert isinstance(dropout, float) if len(d_layers) > 2: assert len(set(d_layers[1:-1])) == 1, ( - 'if d_layers contains more than two elements, then' - ' all elements except for the first and the last ones must be equal.' + "if d_layers contains more than two elements, then" + " all elements except for the first and the last ones must be equal." ) return MLP( d_in=d_in, - d_layers=d_layers, # type: ignore + d_layers=d_layers, dropouts=dropout, - activation='ReLU', + activation="ReLU", d_out=d_out, ) @@ -335,7 +347,7 @@ def __init__( *, d_in: int, n_blocks: int, - d_main: int, + d_main: Optional[int], d_hidden: int, dropout_first: float, dropout_second: float, @@ -378,7 +390,7 @@ def __init__( @classmethod def make_baseline( - cls: Type['ResNet'], + cls: Type["ResNet"], *, d_in: int, n_blocks: int, @@ -387,7 +399,7 @@ def make_baseline( dropout_first: float, dropout_second: float, d_out: int, - ) -> 'ResNet': + ) -> "ResNet": """Create a "baseline" `ResNet`. 
This variation of ResNet was used in [gorishniy2021revisiting]. Features: * :code:`Activation` = :code:`ReLU` @@ -409,8 +421,8 @@ def make_baseline( d_hidden=d_hidden, dropout_first=dropout_first, dropout_second=dropout_second, - normalization='BatchNorm1d', - activation='ReLU', + normalization="BatchNorm1d", + activation="ReLU", d_out=d_out, ) @@ -421,10 +433,12 @@ def forward(self, x: Tensor) -> Tensor: x = self.head(x) return x -#### For diffusion + +# **For diffusion** + class MLPDiffusion(nn.Module): - def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 128): + def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=128): super().__init__() self.dim_t = dim_t self.num_classes = num_classes @@ -432,8 +446,8 @@ def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 128): # d0 = rtdl_params['d_layers'][0] - rtdl_params['d_in'] = dim_t - rtdl_params['d_out'] = d_in + rtdl_params["d_in"] = dim_t + rtdl_params["d_out"] = d_in self.mlp = MLP.make_baseline(**rtdl_params) @@ -441,14 +455,12 @@ def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 128): self.label_emb = nn.Embedding(self.num_classes, dim_t) elif self.num_classes == 0 and is_y_cond: self.label_emb = nn.Linear(1, dim_t) - + self.proj = nn.Linear(d_in, dim_t) self.time_embed = nn.Sequential( - nn.Linear(dim_t, dim_t), - nn.SiLU(), - nn.Linear(dim_t, dim_t) + nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t) ) - + def forward(self, x, timesteps, y=None): emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) if self.is_y_cond and y is not None: @@ -462,28 +474,26 @@ def forward(self, x, timesteps, y=None): class ResNetDiffusion(nn.Module): - def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 256): + def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=256): super().__init__() self.dim_t = dim_t self.num_classes = num_classes - rtdl_params['d_in'] = d_in - rtdl_params['d_out'] = d_in - 
rtdl_params['emb_d'] = dim_t + rtdl_params["d_in"] = d_in + rtdl_params["d_out"] = d_in + rtdl_params["emb_d"] = dim_t self.resnet = ResNet.make_baseline(**rtdl_params) - + if self.num_classes > 0 and is_y_cond: self.label_emb = nn.Embedding(self.num_classes, dim_t) elif self.num_classes == 0 and is_y_cond: self.label_emb = nn.Linear(1, dim_t) - + self.proj = nn.Linear(d_in, dim_t) self.time_embed = nn.Sequential( - nn.Linear(dim_t, dim_t), - nn.SiLU(), - nn.Linear(dim_t, dim_t) + nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t) ) - + def forward(self, x, timesteps, y=None): emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) if self.is_y_cond and y is not None: diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index ff92f275..61ec9eac 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -1,8 +1,15 @@ -import torch +# mypy: disable-error-code=no-untyped-def + +# stdlib +from inspect import isfunction + +# third party import numpy as np +import torch import torch.nn.functional as F -from torch.profiler import record_function -from inspect import isfunction + +# from torch.profiler import record_function + def normal_kl(mean1, logvar1, mean2, logvar2): """ @@ -33,12 +40,15 @@ def normal_kl(mean1, logvar1, mean2, logvar2): + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) + def approx_standard_normal_cdf(x): """ A fast approximation of the cumulative distribution function of the standard normal. 
""" - return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) + return 0.5 * ( + 1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))) + ) def discretized_gaussian_log_likelihood(x, *, means, log_scales): @@ -65,13 +75,16 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): log_probs = torch.where( x < -0.999, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + torch.where( + x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)) + ), ) assert log_probs.shape == x.shape return log_probs + def sum_except_batch(x, num_dims=1): - ''' + """ Sums all dimensions except the first. Args: @@ -80,23 +93,26 @@ def sum_except_batch(x, num_dims=1): Returns: x_sum: Tensor, shape (batch_size,) - ''' + """ return x.reshape(*x.shape[:num_dims], -1).sum(-1) + def mean_flat(tensor): """ Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) + def ohe_to_categories(ohe, K): K = torch.from_numpy(K) indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist() res = [] for i in range(len(indices) - 1): - res.append(ohe[:, indices[i]:indices[i+1]].argmax(dim=1)) + res.append(ohe[:, indices[i] : indices[i + 1]].argmax(dim=1)) return torch.stack(res, dim=1) + def log_1_min_a(a): return torch.log(1 - a.exp() + 1e-40) @@ -105,9 +121,11 @@ def log_add_exp(a, b): maximum = torch.max(a, b) return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) + def exists(x): return x is not None + def extract(a, t, x_shape): b, *_ = t.shape t = t.to(a.device) @@ -116,62 +134,64 @@ def extract(a, t, x_shape): out = out[..., None] return out.expand(x_shape) + def default(val, d): if exists(val): return val return d() if isfunction(d) else d + def log_categorical(log_x_start, log_prob): return (log_x_start.exp() * log_prob).sum(dim=1) + def index_to_log_onehot(x, num_classes): onehots = 
[] for i in range(len(num_classes)): onehots.append(F.one_hot(x[:, i], num_classes[i])) - x_onehot = torch.cat(onehots, dim=1) log_onehot = torch.log(x_onehot.float().clamp(min=1e-30)) return log_onehot + def log_sum_exp_by_classes(x, slices): - device = x.device res = torch.zeros_like(x) for ixs in slices: res[:, ixs] = torch.logsumexp(x[:, ixs], dim=1, keepdim=True) - assert x.size() == res.size() - return res + @torch.jit.script def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: m = torch.maximum(a, b) return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m + @torch.jit.script def sliced_logsumexp(x, slices): lse = torch.logcumsumexp( - torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float('inf')), - dim=-1) + torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float("inf")), dim=-1 + ) slice_starts = slices[:-1] slice_ends = slices[1:] slice_lse = log_sub_exp(lse[:, slice_ends], lse[:, slice_starts]) slice_lse_repeated = torch.repeat_interleave( - slice_lse, - slice_ends - slice_starts, - dim=-1 + slice_lse, slice_ends - slice_starts, dim=-1 ) return slice_lse_repeated + def log_onehot_to_index(log_x): return log_x.argmax(1) class FoundNANsError(BaseException): """Found NANs during sampling""" - def __init__(self, message='Found NANs during sampling.'): + + def __init__(self, message="Found NANs during sampling."): super(FoundNANsError, self).__init__(message) @@ -182,6 +202,7 @@ class TensorDataLoader: the dataset and calls cat (slow). Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 """ + def __init__(self, *tensors, batch_size=32, shuffle=False): """ Initialize a FastTensorDataLoader. 
@@ -196,15 +217,14 @@ def __init__(self, *tensors, batch_size=32, shuffle=False): self.dataset_len = self.tensors[0].shape[0] self.batch_size = batch_size self.shuffle = shuffle - + def __iter__(self): idx = np.arange(self.dataset_len) if self.shuffle: np.random.shuffle(idx) for i in range(0, self.dataset_len, self.batch_size): - s = idx[i:i+self.batch_size] + s = idx[i : i + self.batch_size] yield tuple(t[s] for t in self.tensors) def __len__(self): return len(range(0, self.dataset_len, self.batch_size)) - \ No newline at end of file diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 638b5e6c..95fb8581 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -13,8 +13,8 @@ # synthcity absolute import synthcity.logger as log -from synthcity.utils.serialization import dataframe_hash from synthcity.utils.dataframe import discrete_columns as find_cat_cols +from synthcity.utils.serialization import dataframe_hash # synthcity relative from .data_encoder import ContinuousDataEncoder diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 3a3da116..36336419 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -1,10 +1,11 @@ """ Reference: Kotelnikov, Akim et al. “TabDDPM: Modelling Tabular Data with Diffusion Models.” ArXiv abs/2209.15421 (2022): n. pag. 
""" +# mypy: disable-error-code=override +# flake8: noqa: F401 # stdlib from pathlib import Path -from copy import deepcopy from typing import Any, List, Optional, Union # third party @@ -13,8 +14,6 @@ # Necessary packages from pydantic import validate_arguments -import torch -from torch.utils.data import sampler # synthcity absolute from synthcity.metrics.weighted_metrics import WeightedMetrics @@ -26,7 +25,6 @@ IntegerDistribution, ) from synthcity.plugins.core.models.tabular_ddpm import TabDDPM -from synthcity.plugins.core.models.tabular_encoder import TabularEncoder from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema from synthcity.utils.constants import DEVICE @@ -59,14 +57,14 @@ def __init__( self, *, is_classification: bool = False, - n_iter = 1000, - lr = 0.002, - weight_decay = 1e-4, - batch_size = 1024, - model_type = 'mlp', - num_timesteps = 1000, - gaussian_loss_type = 'mse', - scheduler = 'cosine', + n_iter: int = 1000, + lr: float = 0.002, + weight_decay: float = 1e-4, + batch_size: int = 1024, + model_type: str = "mlp", + num_timesteps: int = 1000, + gaussian_loss_type: str = "mse", + scheduler: str = "cosine", device: Any = DEVICE, verbose: int = 0, log_interval: int = 100, @@ -96,13 +94,10 @@ def __init__( compress_dataset=compress_dataset, **kwargs ) - + self.is_classification = is_classification - rtdl_params = dict( - d_layers = [dim_hidden] * num_layers, - dropout = dropout - ) + rtdl_params = dict(d_layers=[dim_hidden] * num_layers, dropout=dropout) self.model = TabDDPM( n_iter=n_iter, lr=lr, @@ -111,16 +106,16 @@ def __init__( num_timesteps=num_timesteps, gaussian_loss_type=gaussian_loss_type, scheduler=scheduler, - device=device, + device=device, verbose=verbose, - log_interval=log_interval, + log_interval=log_interval, print_interval=print_interval, model_type=model_type, - rtdl_params=rtdl_params, + rtdl_params=rtdl_params, dim_label_emb=dim_label_emb, - n_iter_min=n_iter_min, - 
n_iter_print=n_iter_print, - patience=patience, + n_iter_min=n_iter_min, + n_iter_print=n_iter_print, + patience=patience, ) @staticmethod @@ -158,29 +153,38 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]), ] - def _fit(self, data: DataLoader, cond: pd.Series = None, **kwargs) -> "TabDDPMPlugin": + def _fit( + self, data: DataLoader, cond: Any = None, **kwargs: Any + ) -> "TabDDPMPlugin": if self.is_classification: assert cond is None _, cond = data.unpack() self._labels, self._cond_dist = np.unique(cond, return_counts=True) self._cond_dist = self._cond_dist / self._cond_dist.sum() - - # NOTE: should we include the target column in `data`? - data = data.dataframe() + + # NOTE: should we include the target column in `df`? + df = data.dataframe() if cond is not None: - cond = pd.Series(cond, index=data.index) + cond = pd.Series(cond, index=df.index) # self.encoder = TabularEncoder().fit(X) - - self.model.fit(data, cond, **kwargs) - def _generate(self, count: int, syn_schema: Schema, cond=None, **kwargs: Any) -> DataLoader: + self.model.fit(df, cond, **kwargs) + + return self + + def _generate( + self, count: int, syn_schema: Schema, cond: Any = None, **kwargs: Any + ) -> DataLoader: if self.is_classification and cond is None: # randomly generate labels following the distribution of the training data cond = np.random.choice(self._labels, size=count, p=self._cond_dist) - def callback(count, cond=cond): + + def callback(count, cond=cond): # type: ignore return self.model.generate(count, cond=cond) + return self._safe_generate(callback, count, syn_schema, **kwargs) + plugin = TabDDPMPlugin diff --git a/src/synthcity/utils/dataframe.py b/src/synthcity/utils/dataframe.py index 069b6eab..c12b29da 100644 --- a/src/synthcity/utils/dataframe.py +++ b/src/synthcity/utils/dataframe.py @@ -9,13 +9,15 @@ def constant_columns(dataframe: pd.DataFrame) -> list: return 
discrete_columns(dataframe, 2) -def discrete_columns(dataframe: pd.DataFrame, - max_classes: int = 10, - return_counts=False) -> list: +def discrete_columns( + dataframe: pd.DataFrame, max_classes: int = 10, return_counts: bool = False +) -> list: """ Find columns containing discrete values in a pandas dataframe. """ - return [(col, cnt) if return_counts else col - for col, vals in dataframe.items() - for cnt in [vals.nunique()] - if cnt < max_classes] + return [ + (col, cnt) if return_counts else col + for col, vals in dataframe.items() + for cnt in [vals.nunique()] + if cnt < max_classes + ] From 2a9aa2af051e3e527ec4a857a8951015a2144457 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Wed, 8 Mar 2023 13:24:45 +0100 Subject: [PATCH 13/95] convert assert to conditional AssertionErrors --- .../gaussian_multinomial_diffsuion.py | 54 ++++++++++++------- .../core/models/tabular_ddpm/modules.py | 27 ++++++---- .../plugins/core/models/tabular_ddpm/utils.py | 15 ++++-- src/synthcity/plugins/generic/plugin_ddpm.py | 5 +- 4 files changed, 66 insertions(+), 35 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 7a2b358d..16498772 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -83,8 +83,10 @@ def __init__( ): super(GaussianMultinomialDiffusion, self).__init__() - assert multinomial_loss_type in ("vb_stochastic", "vb_all") - assert parametrization in ("x0", "direct") + if not (multinomial_loss_type in ("vb_stochastic", "vb_all")): + raise AssertionError + if not (parametrization in ("x0", "direct")): + raise AssertionError if verbose: self.print = print @@ -183,12 +185,15 @@ def __init__( .to(device) ) - assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.0e-5 - 
assert ( + if not (log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.0e-5): + raise AssertionError + if not ( log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 - ) - assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.0e-5 + ): + raise AssertionError + if not ((np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.0e-5): + raise AssertionError # Convert to float32 and register buffers. self.register_buffer("alphas", alphas.float().to(device)) @@ -233,14 +238,16 @@ def gaussian_q_mean_variance(self, x_start, t): def gaussian_q_sample(self, x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) - assert noise.shape == x_start.shape + if not (noise.shape == x_start.shape): + raise AssertionError return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): - assert x_start.shape == x_t.shape + if not (x_start.shape == x_t.shape): + raise AssertionError posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t @@ -249,12 +256,13 @@ def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): posterior_log_variance_clipped = extract( self.posterior_log_variance_clipped, t, x_t.shape ) - assert ( + if not ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] - ) + ): + raise AssertionError return posterior_mean, posterior_variance, posterior_log_variance_clipped def gaussian_p_mean_variance( @@ -270,7 +278,8 @@ def gaussian_p_mean_variance( model_kwargs = {} B, C = x.shape[:2] - assert t.shape == (B,) + if not (t.shape == (B,)): + raise AssertionError model_variance = torch.cat( [ @@ -296,9 +305,12 @@ def gaussian_p_mean_variance( x_start=pred_xstart, x_t=x, t=t ) - assert ( + if not ( 
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - ), f"{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}" + ): + raise AssertionError( + f"{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}" + ) return { "mean": model_mean, @@ -326,7 +338,8 @@ def _vb_terms_bpd( decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) - assert decoder_nll.shape == x_start.shape + if not (decoder_nll.shape == x_start.shape): + raise AssertionError decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, @@ -377,7 +390,8 @@ def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None): return terms["loss"] def _predict_xstart_from_eps(self, x_t, t, eps=1e-8): - assert x_t.shape == eps.shape + if not (x_t.shape == eps.shape): + raise AssertionError return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps @@ -450,8 +464,10 @@ def predict_start(self, model_out, log_x_t): # model_out = self._denoise_fn(x_t, t.to(x_t.device), **out_dict) - assert model_out.size(0) == log_x_t.size(0) - assert model_out.size(1) == self.num_classes.sum(), f"{model_out.size()}" + if not (model_out.size(0) == log_x_t.size(0)): + raise AssertionError + if not (model_out.size(1) == self.num_classes.sum()): + raise AssertionError(f"{model_out.size()}") log_pred = torch.empty_like(model_out) for ix in self.slices_for_classes: @@ -465,7 +481,6 @@ def q_posterior(self, log_x_start, log_x_t, t): # EV_log_qxt_x0 = self.q_pred(log_x_start, t) # self.print('sum exp', EV_log_qxt_x0.exp().sum(1).mean()) - # assert False # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1) t_minus_1 = t - 1 @@ -861,7 +876,8 @@ def gaussian_ddim_sample(self, noise, T, cond=None, eta=0.0): def gaussian_ddim_reverse_step( self, model_out_num, x, t, 
clip_denoised=False, eta=0.0 ): - assert eta == 0.0, "Eta must be zero." + if not (eta == 0.0): + raise AssertionError("Eta must be zero.") out = self.gaussian_p_mean_variance( model_out_num, x, diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 48310320..00cef021 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -55,7 +55,8 @@ def _is_glu_activation(activation: ModuleType): def _all_or_none(values): - assert all(x is None for x in values) or all(x is not None for x in values) + if not (all(x is None for x in values) or all(x is not None for x in values)): + raise AssertionError def reglu(x: Tensor) -> Tensor: @@ -63,7 +64,8 @@ def reglu(x: Tensor) -> Tensor: References: [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 """ - assert x.shape[-1] % 2 == 0 + if not (x.shape[-1] % 2 == 0): + raise AssertionError a, b = x.chunk(2, dim=-1) return a * F.relu(b) @@ -73,7 +75,8 @@ def geglu(x: Tensor) -> Tensor: References: [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 """ - assert x.shape[-1] % 2 == 0 + if not (x.shape[-1] % 2 == 0): + raise AssertionError a, b = x.chunk(2, dim=-1) return a * F.gelu(b) @@ -185,8 +188,10 @@ def __init__( super().__init__() if isinstance(dropouts, float): dropouts = [dropouts] * len(d_layers) - assert len(d_layers) == len(dropouts) - assert activation not in ["ReGLU", "GEGLU"] + if not (len(d_layers) == len(dropouts)): + raise AssertionError + if activation in ["ReGLU", "GEGLU"]: + raise AssertionError self.blocks = nn.ModuleList( [ @@ -233,12 +238,14 @@ def make_baseline( References: * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 """ - assert isinstance(dropout, float) + if not (isinstance(dropout, float)): + raise AssertionError if len(d_layers) 
> 2: - assert len(set(d_layers[1:-1])) == 1, ( - "if d_layers contains more than two elements, then" - " all elements except for the first and the last ones must be equal." - ) + if not len(set(d_layers[1:-1])) == 1: + raise AssertionError( + "if d_layers contains more than two elements, then" + " all elements except for the first and the last ones must be equal." + ) return MLP( d_in=d_in, d_layers=d_layers, diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 61ec9eac..c2491e6e 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -23,7 +23,8 @@ def normal_kl(mean1, logvar1, mean2, logvar2): if isinstance(obj, torch.Tensor): tensor = obj break - assert tensor is not None, "at least one argument must be a Tensor" + if tensor is None: + raise AssertionError("at least one argument must be a Tensor") # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). @@ -62,7 +63,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): :param log_scales: the Gaussian log stddev Tensor. :return: a tensor like x of log probabilities (in nats). 
""" - assert x.shape == means.shape == log_scales.shape + if not (x.shape == means.shape == log_scales.shape): + raise AssertionError centered_x = x - means inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1.0 / 255.0) @@ -79,7 +81,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)) ), ) - assert log_probs.shape == x.shape + if not (log_probs.shape == x.shape): + raise AssertionError return log_probs @@ -158,7 +161,8 @@ def log_sum_exp_by_classes(x, slices): res = torch.zeros_like(x) for ixs in slices: res[:, ixs] = torch.logsumexp(x[:, ixs], dim=1, keepdim=True) - assert x.size() == res.size() + if not (x.size() == res.size()): + raise AssertionError return res @@ -212,7 +216,8 @@ def __init__(self, *tensors, batch_size=32, shuffle=False): iterator is created out of this object. :returns: A FastTensorDataLoader. """ - assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + if not all(t.shape[0] == tensors[0].shape[0] for t in tensors): + raise AssertionError self.tensors = tensors self.dataset_len = self.tensors[0].shape[0] self.batch_size = batch_size diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 36336419..b28c6ef3 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -157,7 +157,10 @@ def _fit( self, data: DataLoader, cond: Any = None, **kwargs: Any ) -> "TabDDPMPlugin": if self.is_classification: - assert cond is None + if cond is not None: + raise ValueError( + "cond is already given by the labels for classification" + ) _, cond = data.unpack() self._labels, self._cond_dist = np.unique(cond, return_counts=True) self._cond_dist = self._cond_dist / self._cond_dist.sum() From 246cd5ba0f3bbbbca2ac0c442472b41a10e715a1 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 10 Mar 2023 15:39:38 +0100 Subject: 
[PATCH 14/95] added an auto annotation tool --- src/auto-anno.py | 322 +++++++++++++++++++++++++++++++++++++++++++++++ src/tmp.py | 12 ++ 2 files changed, 334 insertions(+) create mode 100644 src/auto-anno.py create mode 100644 src/tmp.py diff --git a/src/auto-anno.py b/src/auto-anno.py new file mode 100644 index 00000000..a56491ca --- /dev/null +++ b/src/auto-anno.py @@ -0,0 +1,322 @@ +import os +import re +import sys +import ast +import runpy +import shutil +import inspect +import argparse +import cloudpickle +from typing import * +from numbers import * +from itertools import product, islice + + +TYPE_MAP = { # maps of type annotations + Integral: int, + Real: float, + Complex: complex, + object: Any +} + +# MOD_MAP = { # maps module names to their common aliases +# 'numpy': 'np', +# 'pandas': 'pd' +# } + + +def get_type(x): + """ + Examples: + >>> get_type(None) + >>> get_type([]) + list + >>> get_type([1, 2, 3]) + list[int] + >>> get_type([1, 'a']) + list + >>> get_type(dict(a=0.9, b=0.1)) + dict[str, float] + >>> get_type(dict(a=0.9, b='a')) + dict[str, typing.Any] + >>> get_type({1, 2.0, None}) + set[typing.Optional[float]] + >>> get_type(str) + type + >>> get_type(True) + bool + >>> get_type((1, 2.0)) + tuple[int, float] + >>> get_type(tuple(range(9))) + tuple[int, ...] + >>> get_type(iter(range(9))) + typing.Iterator[int] + >>> get_type((i if i % 2 else None for i in range(9))) + typing.Iterator[typing.Optional[int]] + """ + def dispatch(T, *xs, maxlen=5): + xs = [list(map(get_type, l)) for l in xs] + if min(map(len, xs)) == 0: # empty collection + return T + ts = tuple(map(get_common_suptype, xs)) + if len(ts) == 1: + t = ts[0] + elif len(ts) > maxlen: + t = get_common_suptype(ts) + else: + t = ts + if t is object: + return T + elif len(ts) > maxlen: + return T[t, ...] 
+ else: + return T[t] + if x is None: + return None + if inspect.isfunction(x) or inspect.ismethod(x): + return Callable + for t in (list, set, frozenset): + if isinstance(x, t): + return dispatch(t, x) + if isinstance(x, tuple): + return dispatch(tuple, *[[a] for a in x], maxlen=4) + if isinstance(x, dict): + return dispatch(dict, x.keys(), x.values()) + if hasattr(x, '__next__'): + return dispatch(Iterator, islice(x, 10)) + if isinstance(x, bool): + return bool + if isinstance(x, Integral): + return Integral + if isinstance(x, Real): + return Real + if isinstance(x, Complex): + return Complex + return type(x) + + +def get_suptypes(t): + def suptypes_of_subscripted_type(t): + T = t.__origin__ + args = t.__args__ + sts = [T[ts] for ts in product(*map(get_suptypes, args)) + if not all(t in (object, ...) for t in ts)] + return sts + T.mro() + if inspect.isclass(t) and issubclass(t, type): + sts = list(t.__mro__) + elif hasattr(t, '__origin__'): + sts = suptypes_of_subscripted_type(t) + elif isinstance(t, type): + sts = list(t.mro()) + elif t == Ellipsis: + sts = [t] + else: # None, Callable, Iterator, etc. 
+ sts = [t, object] + return sts + + +def get_common_suptype(ts, type_map=None): + """Find the most specific common supertype of a collection of types.""" + ts = set(ts) + assert ts, "empty collection of types" + + optional = any(t is None for t in ts) + ts.discard(None) + + if not ts: + return None + + sts = [get_suptypes(t) for t in ts] + for t in min(sts, key=len): + if all(t in ts for ts in sts): + break + else: + return Any + + if type_map: + t = type_map.get(t, t) + if optional: + t = Optional[t] + return t + + +def test(): + def get_anno(xs): + return get_common_suptype(map(get_type, xs)) + recs = [ + [None, 1, 1.2], + [{1: 2}, {1: 2.2}, {1: 2.1, 3: 4}], + [(x for x in range(10)), iter(range(10))], + ] + for xs in recs: + print(get_anno(xs)) + + +def get_full_name(x, global_vars=()): + if x in (None, Ellipsis): + return repr(x) + mod = x.__module__ + try: + name = getattr(x, '__qualname__', x.__name__) + except AttributeError: + print("WARNING: failed to get name of", x, "in", mod) + name = repr(x) + if mod != 'builtins' and x not in global_vars: + name = mod + '.' 
+ name + return name + + +def profiler(frame, event, arg): + if event in ('call', 'return'): + filename = os.path.abspath(frame.f_code.co_filename) + funcname = frame.f_code.co_name + if filename.endswith('.py') and funcname[0] != '<' and CWD in filename: + recs = TYPE_RECS.setdefault(filename, {}) + if 'globals' not in recs: + recs['globals'] = set(frame.f_globals) + if event == 'call': + arg_types = {var: get_type(val) for var, val in frame.f_locals.items()} + lineno = frame.f_lineno + else: + arg_types = {'return': get_type(arg)} + lineno = max(ln for ln, fn in recs if fn == funcname and + ln <= frame.f_lineno and 'return' not in recs[ln, fn]) + rec = recs.setdefault((lineno, funcname), {}) + for k, v in arg_types.items(): + rec.setdefault(k, []).append(v) + return profiler + + +#*** run the script N times to collect type records *** + +parser = argparse.ArgumentParser() +parser.add_argument('script', help='the script to run') +parser.add_argument('-n', type=int, default=1, + help='number of times to run the script') +parser.add_argument('-v', '--verbose', action='store_true') +parser.add_argument('-i', action='store_true', + help='prompt before overwriting each script') +parser.add_argument('--log', default='type_records.pkl', + help='output file for type records') +parser.add_argument('--cwd', default=None, help='working directory') +parser.add_argument('--backup', action='store_true', + help='backup the scripts before annotating them') + +ARGS = parser.parse_args() +DIR = os.path.dirname(os.path.abspath(ARGS.script)) +CWD = ARGS.cwd or DIR + +try: + TYPE_RECS = cloudpickle.load(open(ARGS.log, 'rb')) +except: + TYPE_RECS = {} # {filename: {(lineno, funcname): {argname: [type]}}}} + +sys.path.extend([DIR, CWD]) +sys.setprofile(profiler) + +for _ in range(ARGS.n): + runpy.run_path(sys.argv[1], run_name='__main__') + +sys.setprofile(None) + +with open(ARGS.log, 'wb') as f: + cloudpickle.dump(TYPE_RECS, f) + + +#*** determine the type annotations from the type 
records *** + +def get_type_annotations(type_records=TYPE_RECS): + def recurse(x): + if isinstance(x, dict): + return {k: recurse(v) for k, v in x.items()} + elif isinstance(x, list): + return get_common_suptype(x, type_map=TYPE_MAP) + else: + return x + return recurse(type_records) + +annotations = get_type_annotations() + +# if ARGS.verbose: +# for path, recs in annotations.items(): +# print(path) +# for (lineno, funcname), arg_types in recs.items(): +# print(f' {funcname} (Ln{lineno}):') +# print(' ' + ', '.join(f'{k}: {get_full_name(v)}' for k, v in arg_types.items())) + + +#*** write the type annotations to the script *** + +def find_defs_in_ast(tree): + def recurse(node): # should be in order + if isinstance(node, ast.FunctionDef): + yield node + for child in ast.iter_child_nodes(node): + yield from recurse(child) + return list(recurse(tree)) + +def annotate_def(def_node, annotations) -> bool: + key = (def_node.lineno, def_node.name) + if key not in annotations: + return False # no type records for this function + annos = annotations[key] + l = def_node.args + all_args = l.posonlyargs + l.args + l.kwonlyargs + changed = False + for a in all_args: + if a.annotation is None and a.arg != 'self': + anno = get_full_name(annos[a.arg], annotations['globals']) + a.annotation = ast.Name(anno) + changed = True + if def_node.returns is None: + anno = get_full_name(annos['return'], annotations['globals']) + def_node.returns = ast.Name(anno) + def_node.returns.lineno = max(a.lineno for a in all_args) + changed = True + return changed + +# def get_aliases(ast): +# # TODO: handle import aliases +# ims = [i for i in ast.body if isinstance(i, ast.ImportFrom)] +# aliases = {} +# for im in ims: + +def annotate_script(filepath, verbose=ARGS.verbose): + s = open(filepath, encoding='utf8').read() + lines = s.splitlines() + defs = [d for d in find_defs_in_ast(ast.parse(s)) + if annotate_def(d, annotations[filepath])] + if not defs: + return None + if verbose: + print('Adding 
annotations to', filepath, '\n') + starts, ends, sigs = [], [], [] + for node in defs: + ln0, ln1 = node.lineno, node.body[0].lineno + starts.append(ln0 - 1) + ends.append(ln1 - 1) + node.body = [] # only keep signature + line = re.match('\s*', lines[ln0-1])[0] + ast.unparse(node) # keep indentation + sigs.append(line) + if verbose: + print('Old:', *lines[ln0-1:ln1], sep='\n') + print('>' * 50) + print('New:', sigs[-1], sep='\n') + print('-' * 50) + new_lines = [] + for s, e, sig in zip([None] + ends, starts + [None], sigs + [None]): + new_lines.extend(lines[s:e]) + if sig is not None: + new_lines.append(sig) + return '\n'.join(new_lines) + + +for path in annotations: + s = annotate_script(path) + if s is None: + continue + if ARGS.backup: + shutil.copy(path, path + '.bak') + if not ARGS.i or input(f"Overwrite {path}?").lower() == 'y': + with open(path, 'w', encoding='utf8') as f: + f.write(s) diff --git a/src/tmp.py b/src/tmp.py new file mode 100644 index 00000000..74bf87ad --- /dev/null +++ b/src/tmp.py @@ -0,0 +1,12 @@ +from synthcity.plugins import Plugins +from sklearn.datasets import load_iris +from synthcity.plugins.core.dataloader import GenericDataLoader + +# loadDebugger() +X, y = load_iris(as_frame = True, return_X_y = True) +X = GenericDataLoader(X.assign(target = y), target_column="target") +plugin = Plugins().get("ddpm", n_iter=3, is_classification=True, + num_timesteps=100, verbose=1) +plugin.fit(X) +X_syn = plugin.model.generate(50) +print(X_syn) From f458bb45770268bb653129c4ff9ba38091fd8214 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 10 Mar 2023 20:42:46 +0100 Subject: [PATCH 15/95] update auto-anno and generate annotations --- src/auto-anno.py | 263 +++++++++---- .../core/models/tabular_ddpm/__init__.py | 12 +- .../gaussian_multinomial_diffsuion.py | 347 ++++++++---------- .../core/models/tabular_ddpm/modules.py | 46 +-- .../plugins/core/models/tabular_ddpm/utils.py | 72 ++-- src/tmp.py | 19 +- 6 files changed, 414 
insertions(+), 345 deletions(-) diff --git a/src/auto-anno.py b/src/auto-anno.py index a56491ca..96225e56 100644 --- a/src/auto-anno.py +++ b/src/auto-anno.py @@ -1,22 +1,30 @@ +# flake8: noqa +# mypy: ignore-errors + +# stdlib +import argparse +import ast +import importlib +import inspect +import io import os import re -import sys -import ast import runpy import shutil -import inspect -import argparse -import cloudpickle -from typing import * +import sys +from collections.abc import Callable, Iterator +from itertools import islice, product from numbers import * -from itertools import product, islice +from typing import Any, Optional, Union +# third party +import cloudpickle TYPE_MAP = { # maps of type annotations Integral: int, Real: float, Complex: complex, - object: Any + object: Any, } # MOD_MAP = { # maps module names to their common aliases @@ -54,9 +62,10 @@ def get_type(x): >>> get_type((i if i % 2 else None for i in range(9))) typing.Iterator[typing.Optional[int]] """ + def dispatch(T, *xs, maxlen=5): xs = [list(map(get_type, l)) for l in xs] - if min(map(len, xs)) == 0: # empty collection + if not xs or min(map(len, xs)) == 0: # empty collection return T ts = tuple(map(get_common_suptype, xs)) if len(ts) == 1: @@ -71,6 +80,7 @@ def dispatch(T, *xs, maxlen=5): return T[t, ...] else: return T[t] + if x is None: return None if inspect.isfunction(x) or inspect.ismethod(x): @@ -82,7 +92,9 @@ def dispatch(T, *xs, maxlen=5): return dispatch(tuple, *[[a] for a in x], maxlen=4) if isinstance(x, dict): return dispatch(dict, x.keys(), x.values()) - if hasattr(x, '__next__'): + if isinstance(x, io.IOBase): + return type(x) + if isinstance(x, Iterator): #! may be too general return dispatch(Iterator, islice(x, 10)) if isinstance(x, bool): return bool @@ -99,12 +111,16 @@ def get_suptypes(t): def suptypes_of_subscripted_type(t): T = t.__origin__ args = t.__args__ - sts = [T[ts] for ts in product(*map(get_suptypes, args)) - if not all(t in (object, ...) 
for t in ts)] - return sts + T.mro() + sts = [ + T[ts] + for ts in product(*map(get_suptypes, args)) + if not all(t in (object, ...) for t in ts) + ] + return sts + get_suptypes(T) + if inspect.isclass(t) and issubclass(t, type): sts = list(t.__mro__) - elif hasattr(t, '__origin__'): + elif hasattr(t, "__origin__"): sts = suptypes_of_subscripted_type(t) elif isinstance(t, type): sts = list(t.mro()) @@ -119,10 +135,10 @@ def get_common_suptype(ts, type_map=None): """Find the most specific common supertype of a collection of types.""" ts = set(ts) assert ts, "empty collection of types" - + optional = any(t is None for t in ts) ts.discard(None) - + if not ts: return None @@ -132,7 +148,7 @@ def get_common_suptype(ts, type_map=None): break else: return Any - + if type_map: t = type_map.get(t, t) if optional: @@ -143,6 +159,7 @@ def get_common_suptype(ts, type_map=None): def test(): def get_anno(xs): return get_common_suptype(map(get_type, xs)) + recs = [ [None, 1, 1.2], [{1: 2}, {1: 2.2}, {1: 2.1, 3: 4}], @@ -152,62 +169,119 @@ def get_anno(xs): print(get_anno(xs)) -def get_full_name(x, global_vars=()): - if x in (None, Ellipsis): - return repr(x) - mod = x.__module__ - try: - name = getattr(x, '__qualname__', x.__name__) - except AttributeError: - print("WARNING: failed to get name of", x, "in", mod) - name = repr(x) - if mod != 'builtins' and x not in global_vars: - name = mod + '.' + name - return name +def get_full_name(x, global_vars={}): + """ + Examples: + >>> import numpy as np + >>> G = lambda: {id(v): k for k, v in globals().items() if k[0] != '_'} + >>> get_full_name(np.ndarray, G()) + 'np.ndarray' + >>> import scipy as sp + >>> get_full_name(sp.sparse.csr_matrix, G()) + 'sp.sparse.csr_matrix' + >>> import scipy.sparse as sps + >>> get_full_name(sparse.csr_matrix, G()) + 'sps.csr_matrix' + """ + + def get_name(x): + if x.__module__ == "typing": + return x._name + return getattr(x, "__qualname__", x.__name__) + + if x is Ellipsis: + return "..." 
+ if x is None: + return "None" + if id(x) in global_vars: + return global_vars[id(x)] + if x.__module__ == "builtins": + return x.__name__ + # handle the subscripted types + if hasattr(x, "__origin__"): + T, args = x.__origin__, x.__args__ + if T is Union and len(args) == 2 and args[1] is type(None): + T, args = Optional, args[:1] + T = get_full_name(T, global_vars) + args = ", ".join(get_full_name(a, global_vars) for a in args) + return f"{T}[{args}]" + # find the module alias + names = (f"{x.__module__}.{get_name(x)}").split(".")[::-1] + mods = [importlib.import_module(names[-1])] + print(names) + for name in names[-2::-1]: + print(name, mods[-1]) + mods.append(getattr(mods[-1], name)) + mods = mods[::-1] + # find the first module that is imported + for i, (name, mod) in enumerate(zip(names, mods)): + if id(mod) in global_vars: + names = names[:i] + [global_vars[id(mod)]] + mods = mods[: i + 1] + break + # skip useless intermediate modules + for k in range(1, len(names)): + if k >= len(names) - 1: + break + for i, (name, mod) in enumerate(zip(names, mods)): + if i + 1 + k >= len(names): + break + if hasattr(mods[-k], name): + names = names[: i + 1] + names[-k:] + mods = mods[: i + 1] + mods[-k:] + break + return ".".join(names[::-1]) def profiler(frame, event, arg): - if event in ('call', 'return'): + if event in ("call", "return"): filename = os.path.abspath(frame.f_code.co_filename) funcname = frame.f_code.co_name - if filename.endswith('.py') and funcname[0] != '<' and CWD in filename: + if filename.endswith(".py") and funcname[0] != "<" and CWD in filename: recs = TYPE_RECS.setdefault(filename, {}) - if 'globals' not in recs: - recs['globals'] = set(frame.f_globals) - if event == 'call': + if "globals" not in recs: + recs["globals", None] = { + id(v): k for k, v in frame.f_globals.items() if k[0] != "_" + } + if event == "call": + # print(filename, funcname, frame.f_lineno, frame.f_locals) arg_types = {var: get_type(val) for var, val in 
frame.f_locals.items()} lineno = frame.f_lineno else: - arg_types = {'return': get_type(arg)} - lineno = max(ln for ln, fn in recs if fn == funcname and - ln <= frame.f_lineno and 'return' not in recs[ln, fn]) + arg_types = {"return": get_type(arg)} + #! assumes no nested function has the same name as the outer function + lineno = max( + ln for ln, fn in recs if fn == funcname and ln <= frame.f_lineno + ) rec = recs.setdefault((lineno, funcname), {}) for k, v in arg_types.items(): rec.setdefault(k, []).append(v) return profiler -#*** run the script N times to collect type records *** +# *** run the script N times to collect type records *** parser = argparse.ArgumentParser() -parser.add_argument('script', help='the script to run') -parser.add_argument('-n', type=int, default=1, - help='number of times to run the script') -parser.add_argument('-v', '--verbose', action='store_true') -parser.add_argument('-i', action='store_true', - help='prompt before overwriting each script') -parser.add_argument('--log', default='type_records.pkl', - help='output file for type records') -parser.add_argument('--cwd', default=None, help='working directory') -parser.add_argument('--backup', action='store_true', - help='backup the scripts before annotating them') +parser.add_argument("script", help="the script to run") +parser.add_argument("-n", type=int, default=1, help="number of times to run the script") +parser.add_argument("-v", "--verbose", action="store_true") +parser.add_argument( + "-i", action="store_true", help="prompt before overwriting each script" +) +parser.add_argument( + "--log", default="type_records.pkl", help="output file for type records" +) +parser.add_argument("--cwd", default=None, help="working directory") +parser.add_argument( + "--backup", action="store_true", help="backup the scripts before annotating them" +) ARGS = parser.parse_args() DIR = os.path.dirname(os.path.abspath(ARGS.script)) CWD = ARGS.cwd or DIR try: - TYPE_RECS = 
cloudpickle.load(open(ARGS.log, 'rb')) + TYPE_RECS = cloudpickle.load(open(ARGS.log, "rb")) except: TYPE_RECS = {} # {filename: {(lineno, funcname): {argname: [type]}}}} @@ -215,15 +289,16 @@ def profiler(frame, event, arg): sys.setprofile(profiler) for _ in range(ARGS.n): - runpy.run_path(sys.argv[1], run_name='__main__') + runpy.run_path(sys.argv[1], run_name="__main__") sys.setprofile(None) -with open(ARGS.log, 'wb') as f: +with open(ARGS.log, "wb") as f: cloudpickle.dump(TYPE_RECS, f) -#*** determine the type annotations from the type records *** +# *** determine the type annotations from the type records *** + def get_type_annotations(type_records=TYPE_RECS): def recurse(x): @@ -233,8 +308,10 @@ def recurse(x): return get_common_suptype(x, type_map=TYPE_MAP) else: return x + return recurse(type_records) + annotations = get_type_annotations() # if ARGS.verbose: @@ -245,7 +322,8 @@ def recurse(x): # print(' ' + ', '.join(f'{k}: {get_full_name(v)}' for k, v in arg_types.items())) -#*** write the type annotations to the script *** +# *** write the type annotations to the script *** + def find_defs_in_ast(tree): def recurse(node): # should be in order @@ -253,62 +331,91 @@ def recurse(node): # should be in order yield node for child in ast.iter_child_nodes(node): yield from recurse(child) + return list(recurse(tree)) -def annotate_def(def_node, annotations) -> bool: + +def annotate_def(def_node: ast.FunctionDef, annotations) -> bool: key = (def_node.lineno, def_node.name) if key not in annotations: return False # no type records for this function annos = annotations[key] - l = def_node.args - all_args = l.posonlyargs + l.args + l.kwonlyargs + A = def_node.args + all_args = A.posonlyargs + A.args + A.kwonlyargs + defaults = dict(zip(A.args + A.kwonlyargs, A.defaults + A.kw_defaults)) + all_args.extend(filter(None, [A.vararg, A.kwarg])) changed = False + global_vars = annotations["globals", None] for a in all_args: - if a.annotation is None and a.arg != 'self': - 
anno = get_full_name(annos[a.arg], annotations['globals']) + if a.annotation is None and a.arg != "self": + t = annos[a.arg] + if a == A.vararg: + if t is tuple: + t = Any + else: + assert t.__origin__ is tuple + if ( + len(t.__args__) == 1 + or len(t.__args__) == 2 + and t.__args__[1] is Ellipsis + ): + t = t.__args__[0] + else: + t = get_common_suptype(t.__args__) + elif a == A.kwarg: + assert t.__origin__ is dict + t = t.__args__[1] + if t is None: + t = Any + if a.arg in defaults: + t = Union[t, get_type(defaults[a.arg])] + anno = get_full_name(t, global_vars) a.annotation = ast.Name(anno) changed = True if def_node.returns is None: - anno = get_full_name(annos['return'], annotations['globals']) + if "return" not in annos: + print("No return type for", key, annos) + exit() + anno = get_full_name(annos["return"], global_vars) def_node.returns = ast.Name(anno) def_node.returns.lineno = max(a.lineno for a in all_args) changed = True return changed -# def get_aliases(ast): -# # TODO: handle import aliases -# ims = [i for i in ast.body if isinstance(i, ast.ImportFrom)] -# aliases = {} -# for im in ims: def annotate_script(filepath, verbose=ARGS.verbose): - s = open(filepath, encoding='utf8').read() + s = open(filepath, encoding="utf8").read() lines = s.splitlines() - defs = [d for d in find_defs_in_ast(ast.parse(s)) - if annotate_def(d, annotations[filepath])] + defs = [ + d + for d in find_defs_in_ast(ast.parse(s)) + if annotate_def(d, annotations[filepath]) + ] if not defs: return None if verbose: - print('Adding annotations to', filepath, '\n') + print("Adding annotations to", filepath, "\n") starts, ends, sigs = [], [], [] for node in defs: ln0, ln1 = node.lineno, node.body[0].lineno starts.append(ln0 - 1) ends.append(ln1 - 1) node.body = [] # only keep signature - line = re.match('\s*', lines[ln0-1])[0] + ast.unparse(node) # keep indentation + line = re.match(r"\s*", lines[ln0 - 1])[0] + ast.unparse( + node + ) # keep indentation sigs.append(line) if verbose: - 
print('Old:', *lines[ln0-1:ln1], sep='\n') - print('>' * 50) - print('New:', sigs[-1], sep='\n') - print('-' * 50) + print("Old:", *lines[ln0 - 1 : ln1 - 1], sep="\n") + print(">" * 50) + print("New:", sigs[-1], sep="\n") + print("-" * 50) new_lines = [] for s, e, sig in zip([None] + ends, starts + [None], sigs + [None]): new_lines.extend(lines[s:e]) if sig is not None: new_lines.append(sig) - return '\n'.join(new_lines) + return "\n".join(new_lines) for path in annotations: @@ -316,7 +423,7 @@ def annotate_script(filepath, verbose=ARGS.verbose): if s is None: continue if ARGS.backup: - shutil.copy(path, path + '.bak') - if not ARGS.i or input(f"Overwrite {path}?").lower() == 'y': - with open(path, 'w', encoding='utf8') as f: + shutil.copy(path, path + ".bak") + if not ARGS.i or input(f"Overwrite {path}?").lower() == "y": + with open(path, "w", encoding="utf8") as f: f.write(s) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 98ac9619..1b6df0cd 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,9 +1,8 @@ -# mypy: allow-untyped-defs, allow-untyped-calls # flake8: noqa: F401 # stdlib from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any, Iterator, Optional, Union # third party import numpy as np @@ -50,13 +49,18 @@ def __init__( self.__dict__.update(locals()) del self.self - def _anneal_lr(self, epoch): + def _anneal_lr(self, epoch: int) -> None: frac_done = epoch / self.n_iter lr = self.lr * (1 - frac_done) for param_group in self.optimizer.param_groups: param_group["lr"] = lr - def _update_ema(self, target_params, source_params, rate=0.999): + def _update_ema( + self, + target_params: Iterator[nn.Parameter], + source_params: Iterator[nn.Parameter], + rate: float = 0.999, + ) -> None: """ Update target parameters to be closer to those of source 
parameters using an exponential moving average. diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 16498772..145f81da 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -4,23 +4,26 @@ - https://github.com/ehoogeboom/multinomial_diffusion - https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ -# mypy: disable-error-code=no-untyped-def # flake8: noqa: F405 # stdlib import math +from typing import Any, Callable, Optional # third party import numpy as np import torch import torch.nn.functional as F +from torch import Tensor, nn # synthcity relative from .modules import MLPDiffusion, ResNetDiffusion -from .utils import * # noqa: F403 +from .utils import * # noqa: F401, F403 -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): +def get_named_beta_schedule( + schedule_name: str, num_diffusion_timesteps: int +) -> np.ndarray: """ Get a pre-defined beta schedule for the given name. The beta schedule library consists of beta schedules which remain similar @@ -46,7 +49,9 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): raise NotImplementedError(f"unknown beta schedule: {schedule_name}") -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps: int, alpha_bar: Callable, max_beta: float = 0.999 +) -> np.ndarray: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
@@ -68,20 +73,19 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( self, - num_numerical_features, - num_categorical_features, - model_type="mlp", - model_params=None, - num_timesteps=1000, - gaussian_loss_type="mse", - gaussian_parametrization="eps", - multinomial_loss_type="vb_stochastic", - parametrization="x0", - scheduler="cosine", - device=torch.device("cpu"), - verbose=0, - ): - + num_numerical_features: int, + num_categorical_features: tuple, + model_type: str = "mlp", + model_params: Optional[dict] = None, + num_timesteps: int = 1000, + gaussian_loss_type: str = "mse", + gaussian_parametrization: str = "eps", + multinomial_loss_type: str = "vb_stochastic", + parametrization: str = "x0", + scheduler: str = "cosine", + device: torch.device = torch.device("cpu"), + verbose: int = 0, + ) -> None: super(GaussianMultinomialDiffusion, self).__init__() if not (multinomial_loss_type in ("vb_stochastic", "vb_all")): raise AssertionError @@ -129,7 +133,7 @@ def __init__( elif model_type == "resnet": self.denoise_fn = ResNetDiffusion(**model_params) else: - raise "Unknown diffusion model type!" 
+ raise NotImplementedError(f"unknown model type: {model_type}") self.gaussian_loss_type = gaussian_loss_type self.gaussian_parametrization = gaussian_parametrization @@ -229,13 +233,17 @@ def __init__( self.register_buffer("Lt_count", torch.zeros(num_timesteps)) # Gaussian part - def gaussian_q_mean_variance(self, x_start, t): + def gaussian_q_mean_variance( + self, x_start: Tensor, t: Tensor + ) -> tuple[Tensor, Tensor, Tensor]: mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_1_min_cumprod_alpha, t, x_start.shape) return mean, variance, log_variance - def gaussian_q_sample(self, x_start, t, noise=None): + def gaussian_q_sample( + self, x_start: Tensor, t: Tensor, noise: Optional[Tensor] = None + ) -> Tensor: if noise is None: noise = torch.randn_like(x_start) if not (noise.shape == x_start.shape): @@ -245,7 +253,9 @@ def gaussian_q_sample(self, x_start, t, noise=None): + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) - def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): + def gaussian_q_posterior_mean_variance( + self, x_start: Tensor, x_t: Tensor, t: Tensor + ) -> tuple[Tensor, Tensor, Tensor]: if not (x_start.shape == x_t.shape): raise AssertionError posterior_mean = ( @@ -267,13 +277,13 @@ def gaussian_q_posterior_mean_variance(self, x_start, x_t, t): def gaussian_p_mean_variance( self, - model_output, - x, - t, - clip_denoised=False, - denoised_fn=None, - model_kwargs=None, - ): + model_output: Tensor, + x: Tensor, + t: Tensor, + clip_denoised: bool = False, + denoised_fn: Optional[nn.Module] = None, + model_kwargs: Any = None, + ) -> dict: if model_kwargs is None: model_kwargs = {} @@ -320,8 +330,14 @@ def gaussian_p_mean_variance( } def _vb_terms_bpd( - self, model_output, x_start, x_t, t, clip_denoised=False, model_kwargs=None - ): + self, + model_output: Tensor, + x_start: Tensor, + x_t: Tensor, + t: 
Tensor, + clip_denoised: bool = False, + model_kwargs: Optional[dict] = None, + ) -> dict: ( true_mean, _, @@ -352,7 +368,7 @@ def _vb_terms_bpd( "true_mean": true_mean, } - def _prior_gaussian(self, x_start): + def _prior_gaussian(self, x_start: Tensor) -> Tensor: """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. @@ -370,7 +386,15 @@ def _prior_gaussian(self, x_start): ) return mean_flat(kl_prior) / np.log(2.0) - def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None): + def _gaussian_loss( + self, + model_out: Tensor, + x_start: Tensor, + x_t: Tensor, + t: Tensor, + noise: Tensor, + model_kwargs: Any = None, + ) -> Tensor: if model_kwargs is None: model_kwargs = {} @@ -389,7 +413,9 @@ def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None): return terms["loss"] - def _predict_xstart_from_eps(self, x_t, t, eps=1e-8): + def _predict_xstart_from_eps( + self, x_t: Tensor, t: Tensor, eps: Tensor = 1e-08 + ) -> Tensor: if not (x_t.shape == eps.shape): raise AssertionError return ( @@ -397,20 +423,22 @@ def _predict_xstart_from_eps(self, x_t, t, eps=1e-8): - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + def _predict_eps_from_xstart( + self, x_t: Tensor, t: Tensor, pred_xstart: Tensor + ) -> Tensor: return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def gaussian_p_sample( self, - model_out, - x, - t, - clip_denoised=False, - denoised_fn=None, - model_kwargs=None, - ): + model_out: Tensor, + x: Tensor, + t: Tensor, + clip_denoised: bool = False, + denoised_fn: Any = None, + model_kwargs: Any = None, + ) -> dict: out = self.gaussian_p_mean_variance( model_out, x, @@ -431,11 +459,11 @@ def gaussian_p_sample( # Multinomial part - def multinomial_kl(self, log_prob1, log_prob2): + def multinomial_kl(self, log_prob1: Tensor, 
log_prob2: Tensor) -> Tensor: kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1) return kl - def q_pred_one_timestep(self, log_x_t, t): + def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor: log_alpha_t = extract(self.log_alpha, t, log_x_t.shape) log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape) @@ -447,7 +475,7 @@ def q_pred_one_timestep(self, log_x_t, t): return log_probs - def q_pred(self, log_x_start, t): + def q_pred(self, log_x_start: Tensor, t: Tensor) -> Tensor: log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) log_1_min_cumprod_alpha = extract( self.log_1_min_cumprod_alpha, t, log_x_start.shape @@ -460,10 +488,7 @@ def q_pred(self, log_x_start, t): return log_probs - def predict_start(self, model_out, log_x_t): - - # model_out = self._denoise_fn(x_t, t.to(x_t.device), **out_dict) - + def predict_start(self, model_out: Tensor, log_x_t: Tensor) -> Tensor: if not (model_out.size(0) == log_x_t.size(0)): raise AssertionError if not (model_out.size(1) == self.num_classes.sum()): @@ -474,15 +499,7 @@ def predict_start(self, model_out, log_x_t): log_pred[:, ix] = F.log_softmax(model_out[:, ix], dim=1) return log_pred - def q_posterior(self, log_x_start, log_x_t, t): - # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0) - # where q(xt | xt-1, x0) = q(xt | xt-1). 
- - # EV_log_qxt_x0 = self.q_pred(log_x_start, t) - - # self.print('sum exp', EV_log_qxt_x0.exp().sum(1).mean()) - - # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1) + def q_posterior(self, log_x_start: Tensor, log_x_t: Tensor, t: Tensor) -> Tensor: t_minus_1 = t - 1 # Remove negative values, will not be used anyway for final decoder t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) @@ -508,7 +525,7 @@ def q_posterior(self, log_x_start, log_x_t, t): return log_EV_xtmin_given_xt_given_xstart - def p_pred(self, model_out, log_x, t): + def p_pred(self, model_out: Tensor, log_x: Tensor, t: Tensor) -> Tensor: if self.parametrization == "x0": log_x_recon = self.predict_start(model_out, log_x) log_model_pred = self.q_posterior( @@ -521,48 +538,12 @@ def p_pred(self, model_out, log_x, t): return log_model_pred @torch.no_grad() - def p_sample(self, model_out, log_x, t): + def p_sample(self, model_out: Tensor, log_x: Tensor, t: Tensor) -> Tensor: model_log_prob = self.p_pred(model_out, log_x=log_x, t=t) out = self.log_sample_categorical(model_log_prob) return out - @torch.no_grad() - def p_sample_loop(self, shape): - device = self.log_alpha.device - - b = shape[0] - # start with random normal image. 
- img = torch.randn(shape, device=device) - - for i in reversed(range(1, self.num_timesteps)): - img = self.p_sample( - img, torch.full((b,), i, device=device, dtype=torch.long) - ) - return img - - @torch.no_grad() - def _sample(self, image_size, batch_size=16): - return self.p_sample_loop((batch_size, 3, image_size, image_size)) - - # @torch.no_grad() - # def interpolate(self, x1, x2, t=None, lam=0.5): - # b, *_, device = *x1.shape, x1.device - # t = default(t, self.num_timesteps - 1) - - # assert x1.shape == x2.shape - - # t_batched = torch.stack([torch.tensor(t, device=device)] * b) - # xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) - - # img = (1 - lam) * xt1 + lam * xt2 - # for i in reversed(range(0, t)): - # img = self.p_sample( - # img, torch.full((b,), i, device=device, dtype=torch.long) - # ) - - # return img - - def log_sample_categorical(self, logits): + def log_sample_categorical(self, logits: Tensor) -> Tensor: full_sample = [] for i in range(len(self.num_classes)): one_class_logits = logits[:, self.slices_for_classes[i]] @@ -574,33 +555,14 @@ def log_sample_categorical(self, logits): log_sample = index_to_log_onehot(full_sample, self.num_classes) return log_sample - def q_sample(self, log_x_start, t): + def q_sample(self, log_x_start: Tensor, t: Tensor) -> Tensor: log_EV_qxt_x0 = self.q_pred(log_x_start, t) log_sample = self.log_sample_categorical(log_EV_qxt_x0) return log_sample - def nll(self, log_x_start): - b = log_x_start.size(0) - device = log_x_start.device - loss = 0 - for t in range(0, self.num_timesteps): - t_array = (torch.ones(b, device=device) * t).long() - - kl = self.compute_Lt( - log_x_start=log_x_start, - log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array), - t=t_array, - ) - - loss += kl - - loss += self.kl_prior(log_x_start) - - return loss - - def kl_prior(self, log_x_start): + def kl_prior(self, log_x_start: Tensor) -> Tensor: b = log_x_start.size(0) device = log_x_start.device ones = torch.ones(b, 
device=device).long() @@ -613,7 +575,14 @@ def kl_prior(self, log_x_start): kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) return sum_except_batch(kl_prior) - def compute_Lt(self, model_out, log_x_start, log_x_t, t, detach_mean=False): + def compute_Lt( + self, + model_out: Tensor, + log_x_start: Tensor, + log_x_t: Tensor, + t: Tensor, + detach_mean: bool = False, + ) -> Tensor: log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t) log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t) @@ -631,7 +600,9 @@ def compute_Lt(self, model_out, log_x_start, log_x_t, t, detach_mean=False): return loss - def sample_time(self, b, device, method="uniform"): + def sample_time( + self, b: int, device: torch.device, method: str = "uniform" + ) -> tuple: if method == "importance": if not (self.Lt_count > 10).all(): return self.sample_time(b, device, method="uniform") @@ -654,8 +625,14 @@ def sample_time(self, b, device, method="uniform"): else: raise ValueError - def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt): - + def _multinomial_loss( + self, + model_out: Tensor, + log_x_start: Tensor, + log_x_t: Tensor, + t: Tensor, + pt: Tensor, + ) -> Tensor: if self.multinomial_loss_type == "vb_stochastic": kl = self.compute_Lt(model_out, log_x_start, log_x_t, t) kl_prior = self.kl_prior(log_x_start) @@ -671,31 +648,7 @@ def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt): else: raise ValueError() - #! Not used - def log_prob(self, x): - b, device = x.size(0), x.device - - if self.training: - #! 
not enough arguments - return self._multinomial_loss(x) - - else: - log_x_start = index_to_log_onehot(x, self.num_classes) - - t, pt = self.sample_time(b, device, "importance") - - kl = self.compute_Lt( - log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t - ) - - kl_prior = self.kl_prior(log_x_start) - - # Upweigh loss term of the kl - loss = kl / pt + kl_prior - - return -loss - - def mixed_loss(self, x, cond=None): + def mixed_loss(self, x: Tensor, cond: Optional[Tensor] = None) -> tuple: b = x.shape[0] device = x.device t, pt = self.sample_time(b, device, "uniform") @@ -735,7 +688,7 @@ def mixed_loss(self, x, cond=None): return loss_multi.mean(), loss_gauss.mean() @torch.no_grad() - def mixed_elbo(self, x0, cond=None): + def mixed_elbo(self, x0: Tensor, cond: Optional[Tensor] = None) -> dict: b = x0.size(0) device = x0.device @@ -748,7 +701,7 @@ def mixed_elbo(self, x0, cond=None): gaussian_loss = [] xstart_mse = [] mse = [] - mu_mse = [] + # mu_mse = [] out_mean = [] true_mean = [] multinomial_loss = [] @@ -810,8 +763,8 @@ def mixed_elbo(self, x0, cond=None): if has_cat: prior_multin = self.kl_prior(log_x_cat) - total_gauss = gaussian_loss.sum(dim=1) + prior_gauss - total_multin = multinomial_loss.sum(dim=1) + prior_multin + total_gauss = torch.sum(gaussian_loss, dim=1) + prior_gauss + total_multin = torch.sum(multinomial_loss, dim=1) + prior_multin return { "total_gaussian": total_gauss, "total_multinomial": total_multin, @@ -826,8 +779,14 @@ def mixed_elbo(self, x0, cond=None): @torch.no_grad() def gaussian_ddim_step( - self, model_out_num, x, t, clip_denoised=False, denoised_fn=None, eta=0.0 - ): + self, + model_out_num: Tensor, + x: Tensor, + t: Tensor, + clip_denoised: bool = False, + denoised_fn: Any = None, + eta: float = 0.0, + ) -> Tensor: out = self.gaussian_p_mean_variance( model_out_num, x, @@ -859,23 +818,29 @@ def gaussian_ddim_step( return sample - @torch.no_grad() - def gaussian_ddim_sample(self, noise, T, cond=None, eta=0.0): - x = 
noise - b = x.shape[0] - device = x.device - for t in reversed(range(T)): - self.print(f"Sample timestep {t:4d}", end="\r") - t_array = (torch.ones(b, device=device) * t).long() - out_num = self.denoise_fn(x, t_array, y=cond) - x = self.gaussian_ddim_step(out_num, x, t_array) - self.print() - return x + # @torch.no_grad() + # def gaussian_ddim_sample(self, noise, T, cond=None, eta=0.0): + # x = noise + # b = x.shape[0] + # device = x.device + # for t in reversed(range(T)): + # self.print(f"Sample timestep {t:4d}", end="\r") + # t_array = (torch.ones(b, device=device) * t).long() + # out_num = self.denoise_fn(x, t_array, y=cond) + # x = self.gaussian_ddim_step(out_num, x, t_array) + # self.print() + # return x @torch.no_grad() def gaussian_ddim_reverse_step( - self, model_out_num, x, t, clip_denoised=False, eta=0.0 - ): + self, + model_out_num: Tensor, + x: Tensor, + t: Tensor, + clip_denoised: bool = False, + denoised_fn: Any = None, + eta: float = 0.0, + ) -> Tensor: if not (eta == 0.0): raise AssertionError("Eta must be zero.") out = self.gaussian_p_mean_variance( @@ -883,7 +848,7 @@ def gaussian_ddim_reverse_step( x, t, clip_denoised=clip_denoised, - denoised_fn=None, + denoised_fn=denoised_fn, model_kwargs=None, ) @@ -899,22 +864,22 @@ def gaussian_ddim_reverse_step( return mean_pred - @torch.no_grad() - def gaussian_ddim_reverse_sample(self, x, T, cond=None): - b = x.shape[0] - device = x.device - for t in range(T): - self.print(f"Reverse timestep {t:4d}", end="\r") - t_array = (torch.ones(b, device=device) * t).long() - out_num = self.denoise_fn(x, t_array, y=cond) - x = self.gaussian_ddim_reverse_step(out_num, x, t_array, eta=0.0) - self.print() - - return x + # @torch.no_grad() + # def gaussian_ddim_reverse_sample(self, x, T, cond=None): + # b = x.shape[0] + # device = x.device + # for t in range(T): + # self.print(f"Reverse timestep {t:4d}", end="\r") + # t_array = (torch.ones(b, device=device) * t).long() + # out_num = self.denoise_fn(x, t_array, y=cond) 
+ # x = self.gaussian_ddim_reverse_step(out_num, x, t_array, eta=0.0) + # self.print() + # return x @torch.no_grad() - def multinomial_ddim_step(self, model_out_cat, log_x_t, t, eta=0.0): - # not ddim, essentially + def multinomial_ddim_step( + self, model_out_cat: Tensor, log_x_t: Tensor, t: Tensor, eta: float = 0.0 + ) -> Tensor: log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t) alpha_bar = extract(self.alphas_cumprod, t, log_x_t.shape) @@ -945,7 +910,7 @@ def multinomial_ddim_step(self, model_out_cat, log_x_t, t, eta=0.0): return out @torch.no_grad() - def sample_ddim(self, num_samples, cond=None): + def sample_ddim(self, num_samples: int, cond: Any = None) -> Tensor: b = num_samples device = self.log_alpha.device z_norm = torch.randn((b, self.num_numerics), device=device) @@ -958,12 +923,6 @@ def sample_ddim(self, num_samples, cond=None): ) log_z = self.log_sample_categorical(uniform_logits) - # y = torch.multinomial( - # cond, - # num_samples=b, - # replacement=True - # ) - # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): self.print(f"Sample timestep {i:4d}", end="\r") t = torch.full((b,), i, device=device, dtype=torch.long) @@ -987,7 +946,7 @@ def sample_ddim(self, num_samples, cond=None): return sample @torch.no_grad() - def sample(self, num_samples, cond=None): + def sample(self, num_samples: int, cond: Any = None) -> Tensor: b = num_samples device = self.log_alpha.device z_norm = torch.randn((b, self.num_numerics), device=device) @@ -1028,7 +987,13 @@ def sample(self, num_samples, cond=None): sample = torch.cat([z_norm, z_cat], dim=1).cpu() return sample - def sample_all(self, num_samples, cond=None, max_batch_size=2000, ddim=False): + def sample_all( + self, + num_samples: int, + cond: Any = None, + max_batch_size: int = 2000, + ddim: bool = False, + ) -> Tensor: if ddim: self.print("Sample using DDIM.") sample_fn = self.sample_ddim diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py 
b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 00cef021..289f37ec 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -1,7 +1,6 @@ """ Code was adapted from https://github.com/Yura52/rtdl """ -# mypy: disable-error-code=no-untyped-def # flake8: noqa: F401 # stdlib @@ -19,11 +18,11 @@ class SiLU(nn.Module): - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: return x * torch.sigmoid(x) -def timestep_embedding(timesteps, dim, max_period=10000): +def timestep_embedding(timesteps: Tensor, dim: int, max_period: int = 10000) -> Tensor: """ Create sinusoidal timestep embeddings. @@ -46,19 +45,6 @@ def timestep_embedding(timesteps, dim, max_period=10000): return embedding -def _is_glu_activation(activation: ModuleType): - return ( - isinstance(activation, str) - and activation.endswith("GLU") - or activation in [ReGLU, GEGLU] - ) - - -def _all_or_none(values): - if not (all(x is None for x in values) or all(x is not None for x in values)): - raise AssertionError - - def reglu(x: Tensor) -> Tensor: """The ReGLU activation function from [1]. 
References: @@ -117,7 +103,7 @@ def forward(self, x: Tensor) -> Tensor: return geglu(x) -def _make_nn_module(module_type: ModuleType, *args) -> nn.Module: +def _make_nn_module(module_type: ModuleType, *args: Any) -> nn.Module: return ( ( ReGLU() @@ -445,7 +431,14 @@ def forward(self, x: Tensor) -> Tensor: class MLPDiffusion(nn.Module): - def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=128): + def __init__( + self, + d_in: int, + num_classes: int, + is_y_cond: bool, + rtdl_params: dict, + dim_t: int = 128, + ) -> None: super().__init__() self.dim_t = dim_t self.num_classes = num_classes @@ -468,7 +461,9 @@ def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=128): nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t) ) - def forward(self, x, timesteps, y=None): + def forward( + self, x: Tensor, timesteps: Tensor, y: Optional[Tensor] = None + ) -> Tensor: emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) if self.is_y_cond and y is not None: if self.num_classes > 0: @@ -481,7 +476,14 @@ def forward(self, x, timesteps, y=None): class ResNetDiffusion(nn.Module): - def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=256): + def __init__( + self, + d_in: int, + num_classes: int, + is_y_cond: bool, + rtdl_params: dict, + dim_t: int = 256, + ) -> None: super().__init__() self.dim_t = dim_t self.num_classes = num_classes @@ -501,7 +503,9 @@ def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=256): nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t) ) - def forward(self, x, timesteps, y=None): + def forward( + self, x: Tensor, timesteps: Tensor, y: Optional[Tensor] = None + ) -> Tensor: emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) if self.is_y_cond and y is not None: if self.num_classes > 0: diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index c2491e6e..b495c8a0 100644 --- 
a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -1,17 +1,16 @@ -# mypy: disable-error-code=no-untyped-def +# flake8: noqa: F401 # stdlib -from inspect import isfunction +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple # third party import numpy as np import torch import torch.nn.functional as F +from torch import Tensor -# from torch.profiler import record_function - -def normal_kl(mean1, logvar1, mean2, logvar2): +def normal_kl(mean1: Tensor, logvar1: Tensor, mean2: Tensor, logvar2: Tensor) -> Tensor: """ Compute the KL divergence between two gaussians. @@ -20,7 +19,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): """ tensor = None for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): + if isinstance(obj, Tensor): tensor = obj break if tensor is None: @@ -29,7 +28,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + x if isinstance(x, Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2) ] @@ -42,7 +41,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): ) -def approx_standard_normal_cdf(x): +def approx_standard_normal_cdf(x: Tensor) -> Tensor: """ A fast approximation of the cumulative distribution function of the standard normal. @@ -52,7 +51,9 @@ def approx_standard_normal_cdf(x): ) -def discretized_gaussian_log_likelihood(x, *, means, log_scales): +def discretized_gaussian_log_likelihood( + x: Tensor, *, means: Tensor, log_scales: Tensor +) -> Tensor: """ Compute the log-likelihood of a Gaussian distribution discretizing to a given image. 
@@ -86,7 +87,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): return log_probs -def sum_except_batch(x, num_dims=1): +def sum_except_batch(x: Tensor, num_dims: int = 1) -> Tensor: """ Sums all dimensions except the first. @@ -100,14 +101,14 @@ def sum_except_batch(x, num_dims=1): return x.reshape(*x.shape[:num_dims], -1).sum(-1) -def mean_flat(tensor): +def mean_flat(tensor: Tensor) -> Tensor: """ Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) -def ohe_to_categories(ohe, K): +def ohe_to_categories(ohe: Tensor, K: np.ndarray) -> Tensor: K = torch.from_numpy(K) indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist() res = [] @@ -116,20 +117,16 @@ def ohe_to_categories(ohe, K): return torch.stack(res, dim=1) -def log_1_min_a(a): +def log_1_min_a(a: Tensor) -> Tensor: return torch.log(1 - a.exp() + 1e-40) -def log_add_exp(a, b): +def log_add_exp(a: Tensor, b: Tensor) -> Tensor: maximum = torch.max(a, b) return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) -def exists(x): - return x is not None - - -def extract(a, t, x_shape): +def extract(a: Tensor, t: Tensor, x_shape: tuple) -> Tensor: b, *_ = t.shape t = t.to(a.device) out = a.gather(-1, t) @@ -138,17 +135,11 @@ def extract(a, t, x_shape): return out.expand(x_shape) -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def log_categorical(log_x_start, log_prob): +def log_categorical(log_x_start: Tensor, log_prob: Tensor) -> Tensor: return (log_x_start.exp() * log_prob).sum(dim=1) -def index_to_log_onehot(x, num_classes): +def index_to_log_onehot(x: Tensor, num_classes: np.ndarray) -> Tensor: onehots = [] for i in range(len(num_classes)): onehots.append(F.one_hot(x[:, i], num_classes[i])) @@ -157,23 +148,14 @@ def index_to_log_onehot(x, num_classes): return log_onehot -def log_sum_exp_by_classes(x, slices): - res = torch.zeros_like(x) - for ixs in 
slices: - res[:, ixs] = torch.logsumexp(x[:, ixs], dim=1, keepdim=True) - if not (x.size() == res.size()): - raise AssertionError - return res - - @torch.jit.script -def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: +def log_sub_exp(a: Tensor, b: Tensor) -> Tensor: m = torch.maximum(a, b) return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m @torch.jit.script -def sliced_logsumexp(x, slices): +def sliced_logsumexp(x: Tensor, slices: Tensor) -> Tensor: lse = torch.logcumsumexp( torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float("inf")), dim=-1 ) @@ -188,14 +170,10 @@ def sliced_logsumexp(x, slices): return slice_lse_repeated -def log_onehot_to_index(log_x): - return log_x.argmax(1) - - class FoundNANsError(BaseException): """Found NANs during sampling""" - def __init__(self, message="Found NANs during sampling."): + def __init__(self, message: str = "Found NANs during sampling.") -> None: super(FoundNANsError, self).__init__(message) @@ -207,7 +185,9 @@ class TensorDataLoader: Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 """ - def __init__(self, *tensors, batch_size=32, shuffle=False): + def __init__( + self, *tensors: Tensor, batch_size: int = 32, shuffle: bool = False + ) -> None: """ Initialize a FastTensorDataLoader. :param *tensors: tensors to store. Must have the same length @ dim 0. 
@@ -223,7 +203,7 @@ def __init__(self, *tensors, batch_size=32, shuffle=False): self.batch_size = batch_size self.shuffle = shuffle - def __iter__(self): + def __iter__(self) -> Iterator[tuple]: idx = np.arange(self.dataset_len) if self.shuffle: np.random.shuffle(idx) @@ -231,5 +211,5 @@ def __iter__(self): s = idx[i : i + self.batch_size] yield tuple(t[s] for t in self.tensors) - def __len__(self): + def __len__(self) -> int: return len(range(0, self.dataset_len, self.batch_size)) diff --git a/src/tmp.py b/src/tmp.py index 74bf87ad..6a58832b 100644 --- a/src/tmp.py +++ b/src/tmp.py @@ -1,12 +1,21 @@ -from synthcity.plugins import Plugins +# third party from sklearn.datasets import load_iris + +# synthcity absolute +from synthcity.plugins import Plugins from synthcity.plugins.core.dataloader import GenericDataLoader # loadDebugger() -X, y = load_iris(as_frame = True, return_X_y = True) -X = GenericDataLoader(X.assign(target = y), target_column="target") -plugin = Plugins().get("ddpm", n_iter=3, is_classification=True, - num_timesteps=100, verbose=1) +X, y = load_iris(as_frame=True, return_X_y=True) +X = GenericDataLoader(X.assign(target=y), target_column="target") +plugin = Plugins().get( + "ddpm", + n_iter=3, + is_classification=True, + gaussian_loss_type="mse", + num_timesteps=100, + verbose=1, +) plugin.fit(X) X_syn = plugin.model.generate(50) print(X_syn) From 137c1767b508ce6976df26b493e8b5bb70208b38 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 10 Mar 2023 20:50:32 +0100 Subject: [PATCH 16/95] remove auto-anno and flake8 noqa --- src/auto-anno.py | 429 ------------------ .../core/models/tabular_ddpm/__init__.py | 5 +- .../gaussian_multinomial_diffsuion.py | 2 +- .../core/models/tabular_ddpm/modules.py | 14 +- .../plugins/core/models/tabular_ddpm/utils.py | 4 +- src/tmp.py | 21 - 6 files changed, 9 insertions(+), 466 deletions(-) delete mode 100644 src/auto-anno.py delete mode 100644 src/tmp.py diff --git a/src/auto-anno.py 
b/src/auto-anno.py deleted file mode 100644 index 96225e56..00000000 --- a/src/auto-anno.py +++ /dev/null @@ -1,429 +0,0 @@ -# flake8: noqa -# mypy: ignore-errors - -# stdlib -import argparse -import ast -import importlib -import inspect -import io -import os -import re -import runpy -import shutil -import sys -from collections.abc import Callable, Iterator -from itertools import islice, product -from numbers import * -from typing import Any, Optional, Union - -# third party -import cloudpickle - -TYPE_MAP = { # maps of type annotations - Integral: int, - Real: float, - Complex: complex, - object: Any, -} - -# MOD_MAP = { # maps module names to their common aliases -# 'numpy': 'np', -# 'pandas': 'pd' -# } - - -def get_type(x): - """ - Examples: - >>> get_type(None) - >>> get_type([]) - list - >>> get_type([1, 2, 3]) - list[int] - >>> get_type([1, 'a']) - list - >>> get_type(dict(a=0.9, b=0.1)) - dict[str, float] - >>> get_type(dict(a=0.9, b='a')) - dict[str, typing.Any] - >>> get_type({1, 2.0, None}) - set[typing.Optional[float]] - >>> get_type(str) - type - >>> get_type(True) - bool - >>> get_type((1, 2.0)) - tuple[int, float] - >>> get_type(tuple(range(9))) - tuple[int, ...] - >>> get_type(iter(range(9))) - typing.Iterator[int] - >>> get_type((i if i % 2 else None for i in range(9))) - typing.Iterator[typing.Optional[int]] - """ - - def dispatch(T, *xs, maxlen=5): - xs = [list(map(get_type, l)) for l in xs] - if not xs or min(map(len, xs)) == 0: # empty collection - return T - ts = tuple(map(get_common_suptype, xs)) - if len(ts) == 1: - t = ts[0] - elif len(ts) > maxlen: - t = get_common_suptype(ts) - else: - t = ts - if t is object: - return T - elif len(ts) > maxlen: - return T[t, ...] 
- else: - return T[t] - - if x is None: - return None - if inspect.isfunction(x) or inspect.ismethod(x): - return Callable - for t in (list, set, frozenset): - if isinstance(x, t): - return dispatch(t, x) - if isinstance(x, tuple): - return dispatch(tuple, *[[a] for a in x], maxlen=4) - if isinstance(x, dict): - return dispatch(dict, x.keys(), x.values()) - if isinstance(x, io.IOBase): - return type(x) - if isinstance(x, Iterator): #! may be too general - return dispatch(Iterator, islice(x, 10)) - if isinstance(x, bool): - return bool - if isinstance(x, Integral): - return Integral - if isinstance(x, Real): - return Real - if isinstance(x, Complex): - return Complex - return type(x) - - -def get_suptypes(t): - def suptypes_of_subscripted_type(t): - T = t.__origin__ - args = t.__args__ - sts = [ - T[ts] - for ts in product(*map(get_suptypes, args)) - if not all(t in (object, ...) for t in ts) - ] - return sts + get_suptypes(T) - - if inspect.isclass(t) and issubclass(t, type): - sts = list(t.__mro__) - elif hasattr(t, "__origin__"): - sts = suptypes_of_subscripted_type(t) - elif isinstance(t, type): - sts = list(t.mro()) - elif t == Ellipsis: - sts = [t] - else: # None, Callable, Iterator, etc. 
- sts = [t, object] - return sts - - -def get_common_suptype(ts, type_map=None): - """Find the most specific common supertype of a collection of types.""" - ts = set(ts) - assert ts, "empty collection of types" - - optional = any(t is None for t in ts) - ts.discard(None) - - if not ts: - return None - - sts = [get_suptypes(t) for t in ts] - for t in min(sts, key=len): - if all(t in ts for ts in sts): - break - else: - return Any - - if type_map: - t = type_map.get(t, t) - if optional: - t = Optional[t] - return t - - -def test(): - def get_anno(xs): - return get_common_suptype(map(get_type, xs)) - - recs = [ - [None, 1, 1.2], - [{1: 2}, {1: 2.2}, {1: 2.1, 3: 4}], - [(x for x in range(10)), iter(range(10))], - ] - for xs in recs: - print(get_anno(xs)) - - -def get_full_name(x, global_vars={}): - """ - Examples: - >>> import numpy as np - >>> G = lambda: {id(v): k for k, v in globals().items() if k[0] != '_'} - >>> get_full_name(np.ndarray, G()) - 'np.ndarray' - >>> import scipy as sp - >>> get_full_name(sp.sparse.csr_matrix, G()) - 'sp.sparse.csr_matrix' - >>> import scipy.sparse as sps - >>> get_full_name(sparse.csr_matrix, G()) - 'sps.csr_matrix' - """ - - def get_name(x): - if x.__module__ == "typing": - return x._name - return getattr(x, "__qualname__", x.__name__) - - if x is Ellipsis: - return "..." 
- if x is None: - return "None" - if id(x) in global_vars: - return global_vars[id(x)] - if x.__module__ == "builtins": - return x.__name__ - # handle the subscripted types - if hasattr(x, "__origin__"): - T, args = x.__origin__, x.__args__ - if T is Union and len(args) == 2 and args[1] is type(None): - T, args = Optional, args[:1] - T = get_full_name(T, global_vars) - args = ", ".join(get_full_name(a, global_vars) for a in args) - return f"{T}[{args}]" - # find the module alias - names = (f"{x.__module__}.{get_name(x)}").split(".")[::-1] - mods = [importlib.import_module(names[-1])] - print(names) - for name in names[-2::-1]: - print(name, mods[-1]) - mods.append(getattr(mods[-1], name)) - mods = mods[::-1] - # find the first module that is imported - for i, (name, mod) in enumerate(zip(names, mods)): - if id(mod) in global_vars: - names = names[:i] + [global_vars[id(mod)]] - mods = mods[: i + 1] - break - # skip useless intermediate modules - for k in range(1, len(names)): - if k >= len(names) - 1: - break - for i, (name, mod) in enumerate(zip(names, mods)): - if i + 1 + k >= len(names): - break - if hasattr(mods[-k], name): - names = names[: i + 1] + names[-k:] - mods = mods[: i + 1] + mods[-k:] - break - return ".".join(names[::-1]) - - -def profiler(frame, event, arg): - if event in ("call", "return"): - filename = os.path.abspath(frame.f_code.co_filename) - funcname = frame.f_code.co_name - if filename.endswith(".py") and funcname[0] != "<" and CWD in filename: - recs = TYPE_RECS.setdefault(filename, {}) - if "globals" not in recs: - recs["globals", None] = { - id(v): k for k, v in frame.f_globals.items() if k[0] != "_" - } - if event == "call": - # print(filename, funcname, frame.f_lineno, frame.f_locals) - arg_types = {var: get_type(val) for var, val in frame.f_locals.items()} - lineno = frame.f_lineno - else: - arg_types = {"return": get_type(arg)} - #! 
assumes no nested function has the same name as the outer function - lineno = max( - ln for ln, fn in recs if fn == funcname and ln <= frame.f_lineno - ) - rec = recs.setdefault((lineno, funcname), {}) - for k, v in arg_types.items(): - rec.setdefault(k, []).append(v) - return profiler - - -# *** run the script N times to collect type records *** - -parser = argparse.ArgumentParser() -parser.add_argument("script", help="the script to run") -parser.add_argument("-n", type=int, default=1, help="number of times to run the script") -parser.add_argument("-v", "--verbose", action="store_true") -parser.add_argument( - "-i", action="store_true", help="prompt before overwriting each script" -) -parser.add_argument( - "--log", default="type_records.pkl", help="output file for type records" -) -parser.add_argument("--cwd", default=None, help="working directory") -parser.add_argument( - "--backup", action="store_true", help="backup the scripts before annotating them" -) - -ARGS = parser.parse_args() -DIR = os.path.dirname(os.path.abspath(ARGS.script)) -CWD = ARGS.cwd or DIR - -try: - TYPE_RECS = cloudpickle.load(open(ARGS.log, "rb")) -except: - TYPE_RECS = {} # {filename: {(lineno, funcname): {argname: [type]}}}} - -sys.path.extend([DIR, CWD]) -sys.setprofile(profiler) - -for _ in range(ARGS.n): - runpy.run_path(sys.argv[1], run_name="__main__") - -sys.setprofile(None) - -with open(ARGS.log, "wb") as f: - cloudpickle.dump(TYPE_RECS, f) - - -# *** determine the type annotations from the type records *** - - -def get_type_annotations(type_records=TYPE_RECS): - def recurse(x): - if isinstance(x, dict): - return {k: recurse(v) for k, v in x.items()} - elif isinstance(x, list): - return get_common_suptype(x, type_map=TYPE_MAP) - else: - return x - - return recurse(type_records) - - -annotations = get_type_annotations() - -# if ARGS.verbose: -# for path, recs in annotations.items(): -# print(path) -# for (lineno, funcname), arg_types in recs.items(): -# print(f' {funcname} 
(Ln{lineno}):') -# print(' ' + ', '.join(f'{k}: {get_full_name(v)}' for k, v in arg_types.items())) - - -# *** write the type annotations to the script *** - - -def find_defs_in_ast(tree): - def recurse(node): # should be in order - if isinstance(node, ast.FunctionDef): - yield node - for child in ast.iter_child_nodes(node): - yield from recurse(child) - - return list(recurse(tree)) - - -def annotate_def(def_node: ast.FunctionDef, annotations) -> bool: - key = (def_node.lineno, def_node.name) - if key not in annotations: - return False # no type records for this function - annos = annotations[key] - A = def_node.args - all_args = A.posonlyargs + A.args + A.kwonlyargs - defaults = dict(zip(A.args + A.kwonlyargs, A.defaults + A.kw_defaults)) - all_args.extend(filter(None, [A.vararg, A.kwarg])) - changed = False - global_vars = annotations["globals", None] - for a in all_args: - if a.annotation is None and a.arg != "self": - t = annos[a.arg] - if a == A.vararg: - if t is tuple: - t = Any - else: - assert t.__origin__ is tuple - if ( - len(t.__args__) == 1 - or len(t.__args__) == 2 - and t.__args__[1] is Ellipsis - ): - t = t.__args__[0] - else: - t = get_common_suptype(t.__args__) - elif a == A.kwarg: - assert t.__origin__ is dict - t = t.__args__[1] - if t is None: - t = Any - if a.arg in defaults: - t = Union[t, get_type(defaults[a.arg])] - anno = get_full_name(t, global_vars) - a.annotation = ast.Name(anno) - changed = True - if def_node.returns is None: - if "return" not in annos: - print("No return type for", key, annos) - exit() - anno = get_full_name(annos["return"], global_vars) - def_node.returns = ast.Name(anno) - def_node.returns.lineno = max(a.lineno for a in all_args) - changed = True - return changed - - -def annotate_script(filepath, verbose=ARGS.verbose): - s = open(filepath, encoding="utf8").read() - lines = s.splitlines() - defs = [ - d - for d in find_defs_in_ast(ast.parse(s)) - if annotate_def(d, annotations[filepath]) - ] - if not defs: - return 
None - if verbose: - print("Adding annotations to", filepath, "\n") - starts, ends, sigs = [], [], [] - for node in defs: - ln0, ln1 = node.lineno, node.body[0].lineno - starts.append(ln0 - 1) - ends.append(ln1 - 1) - node.body = [] # only keep signature - line = re.match(r"\s*", lines[ln0 - 1])[0] + ast.unparse( - node - ) # keep indentation - sigs.append(line) - if verbose: - print("Old:", *lines[ln0 - 1 : ln1 - 1], sep="\n") - print(">" * 50) - print("New:", sigs[-1], sep="\n") - print("-" * 50) - new_lines = [] - for s, e, sig in zip([None] + ends, starts + [None], sigs + [None]): - new_lines.extend(lines[s:e]) - if sig is not None: - new_lines.append(sig) - return "\n".join(new_lines) - - -for path in annotations: - s = annotate_script(path) - if s is None: - continue - if ARGS.backup: - shutil.copy(path, path + ".bak") - if not ARGS.i or input(f"Overwrite {path}?").lower() == "y": - with open(path, "w", encoding="utf8") as f: - f.write(s) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 1b6df0cd..91ed41e8 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,8 +1,6 @@ -# flake8: noqa: F401 - # stdlib from copy import deepcopy -from typing import Any, Iterator, Optional, Union +from typing import Any, Iterator, Optional # third party import numpy as np @@ -12,7 +10,6 @@ from torch import nn # synthcity absolute -from synthcity.metrics.weighted_metrics import WeightedMetrics from synthcity.utils.constants import DEVICE from synthcity.utils.dataframe import discrete_columns diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 145f81da..fa280dc0 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ 
b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -18,7 +18,7 @@ # synthcity relative from .modules import MLPDiffusion, ResNetDiffusion -from .utils import * # noqa: F401, F403 +from .utils import * def get_named_beta_schedule( diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 289f37ec..5ceec2d0 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -1,11 +1,9 @@ """ Code was adapted from https://github.com/Yura52/rtdl """ -# flake8: noqa: F401 - # stdlib import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Callable, Optional, Union # third party import torch @@ -162,8 +160,8 @@ def __init__( self, *, d_in: int, - d_layers: List[int], - dropouts: Union[float, List[float]], + d_layers: list[int], + dropouts: Union[float, list[float]], activation: Union[str, Callable[[], nn.Module]], d_out: int, ) -> None: @@ -195,9 +193,9 @@ def __init__( @classmethod def make_baseline( - cls: Type["MLP"], + cls: type["MLP"], d_in: int, - d_layers: List[int], + d_layers: list[int], dropout: float, d_out: int, ) -> "MLP": @@ -383,7 +381,7 @@ def __init__( @classmethod def make_baseline( - cls: Type["ResNet"], + cls: type["ResNet"], *, d_in: int, n_blocks: int, diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index b495c8a0..4aa59b95 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -1,7 +1,5 @@ -# flake8: noqa: F401 - # stdlib -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple +from typing import Iterator # third party import numpy as np diff --git a/src/tmp.py b/src/tmp.py deleted file mode 100644 index 6a58832b..00000000 --- 
a/src/tmp.py +++ /dev/null @@ -1,21 +0,0 @@ -# third party -from sklearn.datasets import load_iris - -# synthcity absolute -from synthcity.plugins import Plugins -from synthcity.plugins.core.dataloader import GenericDataLoader - -# loadDebugger() -X, y = load_iris(as_frame=True, return_X_y=True) -X = GenericDataLoader(X.assign(target=y), target_column="target") -plugin = Plugins().get( - "ddpm", - n_iter=3, - is_classification=True, - gaussian_loss_type="mse", - num_timesteps=100, - verbose=1, -) -plugin.fit(X) -X_syn = plugin.model.generate(50) -print(X_syn) From 6c4af11b25f5c58277f352546d4e48aebc7331fe Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 10 Mar 2023 23:37:19 +0100 Subject: [PATCH 17/95] add python<3.9 compatible annotations --- src/synthcity/plugins/core/models/tabular_ddpm/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 4aa59b95..4d4c92bd 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -1,3 +1,6 @@ +# future +from __future__ import annotations + # stdlib from typing import Iterator From 191cdcc77892aba4099789aa42e734c44199f6bf Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 10 Mar 2023 23:39:53 +0100 Subject: [PATCH 18/95] remove star import --- .../gaussian_multinomial_diffsuion.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index fa280dc0..88925ae2 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -4,8 +4,6 @@ - 
https://github.com/ehoogeboom/multinomial_diffusion - https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ -# flake8: noqa: F405 - # stdlib import math from typing import Any, Callable, Optional @@ -18,7 +16,20 @@ # synthcity relative from .modules import MLPDiffusion, ResNetDiffusion -from .utils import * +from .utils import ( + FoundNANsError, + discretized_gaussian_log_likelihood, + extract, + index_to_log_onehot, + log_1_min_a, + log_add_exp, + log_categorical, + mean_flat, + normal_kl, + ohe_to_categories, + sliced_logsumexp, + sum_except_batch, +) def get_named_beta_schedule( From 9349a66d3518a12b37cad643af76886b9e27fcf8 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 12 Mar 2023 16:10:09 +0100 Subject: [PATCH 19/95] replace builtin type annos to typing annos --- .../core/models/tabular_ddpm/__init__.py | 16 +++++++-------- .../gaussian_multinomial_diffsuion.py | 6 +++--- .../core/models/tabular_ddpm/modules.py | 8 ++++---- src/test.py | 20 +++++++++++++++++++ 4 files changed, 35 insertions(+), 15 deletions(-) create mode 100644 src/test.py diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 91ed41e8..0618becd 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,6 +1,7 @@ # stdlib +from collections.abc import Iterator from copy import deepcopy -from typing import Any, Iterator, Optional +from typing import Any, Dict, Optional # third party import numpy as np @@ -24,18 +25,18 @@ def __init__( self, n_iter: int = 1000, lr: float = 0.002, - weight_decay: float = 1e-4, + weight_decay: float = 0.0001, batch_size: int = 1024, num_timesteps: int = 1000, gaussian_loss_type: str = "mse", scheduler: str = "cosine", - device: Any = DEVICE, + device: 
torch.device = DEVICE, verbose: int = 0, log_interval: int = 10, print_interval: int = 100, # model params model_type: str = "mlp", - rtdl_params: Optional[dict] = None, # {'d_layers', 'dropout'} + rtdl_params: Optional[Dict[str, Any]] = None, dim_label_emb: int = 128, # early stopping n_iter_min: int = 100, @@ -68,7 +69,9 @@ def _update_ema( for targ, src in zip(target_params, source_params): targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate) - def fit(self, X: pd.DataFrame, cond: Any = None, **kwargs: Any) -> "TabDDPM": + def fit( + self, X: pd.DataFrame, cond: Optional[pd.Series] = None, **kwargs: Any + ) -> "TabDDPM": if cond is not None: n_labels = cond.nunique() else: @@ -180,7 +183,4 @@ def generate(self, count: int, cond: Any = None) -> np.ndarray: cond = torch.tensor(cond, dtype=torch.long, device=self.device) sample = self.diffusion.sample_all(count, cond).detach().cpu().numpy() sample = sample[:, self._col_perm] - if self.verbose: - print("Generated sample") - print(sample) return sample diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 88925ae2..ad1a776f 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -6,7 +6,7 @@ """ # stdlib import math -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Tuple # third party import numpy as np @@ -246,7 +246,7 @@ def __init__( # Gaussian part def gaussian_q_mean_variance( self, x_start: Tensor, t: Tensor - ) -> tuple[Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor]: mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_1_min_cumprod_alpha, t, x_start.shape) @@ -266,7 +266,7 @@ def 
gaussian_q_sample( def gaussian_q_posterior_mean_variance( self, x_start: Tensor, x_t: Tensor, t: Tensor - ) -> tuple[Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor]: if not (x_start.shape == x_t.shape): raise AssertionError posterior_mean = ( diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 5ceec2d0..cc8a56ad 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -3,7 +3,7 @@ """ # stdlib import math -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union # third party import torch @@ -160,8 +160,8 @@ def __init__( self, *, d_in: int, - d_layers: list[int], - dropouts: Union[float, list[float]], + d_layers: List[int], + dropouts: Union[float, List[float]], activation: Union[str, Callable[[], nn.Module]], d_out: int, ) -> None: @@ -195,7 +195,7 @@ def __init__( def make_baseline( cls: type["MLP"], d_in: int, - d_layers: list[int], + d_layers: List[int], dropout: float, d_out: int, ) -> "MLP": diff --git a/src/test.py b/src/test.py new file mode 100644 index 00000000..eb68198a --- /dev/null +++ b/src/test.py @@ -0,0 +1,20 @@ +# third party +from sklearn.datasets import load_iris + +# synthcity absolute +from synthcity.plugins import Plugins +from synthcity.plugins.core.dataloader import GenericDataLoader + +# loadDebugger() +X, y = load_iris(as_frame=True, return_X_y=True) +X = GenericDataLoader(X.assign(target=y), target_column="target") +plugin = Plugins().get( + "ddpm", + n_iter=3, + is_classification=True, + gaussian_loss_type="mse", + num_timesteps=100, + verbose=1, +) +plugin.fit(X) +X_syn = plugin.model.generate(50) From 02579e97e21f325a2deb0c6c3aa0e699f432147d Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 12 Mar 2023 16:54:48 +0100 Subject: [PATCH 20/95] resolve py38 compatibility issue --- 
src/synthcity/plugins/core/models/tabular_ddpm/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index cc8a56ad..8caad49f 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -193,7 +193,7 @@ def __init__( @classmethod def make_baseline( - cls: type["MLP"], + cls, d_in: int, d_layers: List[int], dropout: float, @@ -381,7 +381,7 @@ def __init__( @classmethod def make_baseline( - cls: type["ResNet"], + cls, *, d_in: int, n_blocks: int, From f930bc0f196c491903121556b755df5ba28dec73 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 12 Mar 2023 22:07:17 +0100 Subject: [PATCH 21/95] tests/plugins/generic/test_ddpm.py --- src/test.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 src/test.py diff --git a/src/test.py b/src/test.py deleted file mode 100644 index eb68198a..00000000 --- a/src/test.py +++ /dev/null @@ -1,20 +0,0 @@ -# third party -from sklearn.datasets import load_iris - -# synthcity absolute -from synthcity.plugins import Plugins -from synthcity.plugins.core.dataloader import GenericDataLoader - -# loadDebugger() -X, y = load_iris(as_frame=True, return_X_y=True) -X = GenericDataLoader(X.assign(target=y), target_column="target") -plugin = Plugins().get( - "ddpm", - n_iter=3, - is_classification=True, - gaussian_loss_type="mse", - num_timesteps=100, - verbose=1, -) -plugin.fit(X) -X_syn = plugin.model.generate(50) From 3cf73d7d7d610e1a97a2cde0067830208059ce85 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 13 Mar 2023 10:37:09 +0100 Subject: [PATCH 22/95] change TabDDPM method signatures --- src/synthcity/plugins/generic/plugin_ddpm.py | 36 ++++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git 
a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index b28c6ef3..6556ec05 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -1,12 +1,10 @@ """ Reference: Kotelnikov, Akim et al. “TabDDPM: Modelling Tabular Data with Diffusion Models.” ArXiv abs/2209.15421 (2022): n. pag. """ -# mypy: disable-error-code=override -# flake8: noqa: F401 # stdlib from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Any, List # third party import numpy as np @@ -16,14 +14,8 @@ from pydantic import validate_arguments # synthcity absolute -from synthcity.metrics.weighted_metrics import WeightedMetrics from synthcity.plugins.core.dataloader import DataLoader -from synthcity.plugins.core.distribution import ( - CategoricalDistribution, - Distribution, - FloatDistribution, - IntegerDistribution, -) +from synthcity.plugins.core.distribution import CategoricalDistribution, Distribution from synthcity.plugins.core.models.tabular_ddpm import TabDDPM from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema @@ -153,20 +145,28 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]), ] - def _fit( - self, data: DataLoader, cond: Any = None, **kwargs: Any - ) -> "TabDDPMPlugin": + def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": + cond = None + if args: + if len(args) > 1: + raise ValueError("Only one positional argument is allowed") + if "cond" in kwargs: + raise ValueError("cond is already given by the positional argument") + cond = args[0] + elif "cond" in kwargs: + cond = kwargs.pop("cond") + if self.is_classification: if cond is not None: raise ValueError( "cond is already given by the labels for classification" ) - _, cond = data.unpack() + _, cond = X.unpack() self._labels, self._cond_dist = 
np.unique(cond, return_counts=True) self._cond_dist = self._cond_dist / self._cond_dist.sum() # NOTE: should we include the target column in `df`? - df = data.dataframe() + df = X.dataframe() if cond is not None: cond = pd.Series(cond, index=df.index) @@ -177,9 +177,9 @@ def _fit( return self - def _generate( - self, count: int, syn_schema: Schema, cond: Any = None, **kwargs: Any - ) -> DataLoader: + def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader: + cond = kwargs.pop("cond", None) + if self.is_classification and cond is None: # randomly generate labels following the distribution of the training data cond = np.random.choice(self._labels, size=count, p=self._cond_dist) From 5d37c4ba5ba7e1d4a56e783aa7a1b8aaa24f1bcd Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 13 Mar 2023 13:09:30 +0100 Subject: [PATCH 23/95] remove Iterator subscription --- src/synthcity/plugins/core/models/tabular_ddpm/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 0618becd..35910001 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -55,8 +55,8 @@ def _anneal_lr(self, epoch: int) -> None: def _update_ema( self, - target_params: Iterator[nn.Parameter], - source_params: Iterator[nn.Parameter], + target_params: Iterator, + source_params: Iterator, rate: float = 0.999, ) -> None: """ From 681ba607bc7c4052989e7791be8f9dacfe329d12 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Wed, 15 Mar 2023 17:09:25 +0100 Subject: [PATCH 24/95] update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging --- .gitignore | 1 + src/synthcity/plugins/core/models/mlp.py | 37 +- .../core/models/tabular_ddpm/__init__.py | 57 +- .../gaussian_multinomial_diffsuion.py | 249 +++----- 
.../core/models/tabular_ddpm/modules.py | 546 +++--------------- .../plugins/core/models/tabular_ddpm/utils.py | 35 +- src/synthcity/plugins/generic/plugin_ddpm.py | 109 ++-- src/synthcity/utils/callbacks.py | 91 +++ src/synthcity/utils/dataframe.py | 4 +- tests/plugins/generic/test_ddpm.py | 9 +- 10 files changed, 416 insertions(+), 722 deletions(-) create mode 100644 src/synthcity/utils/callbacks.py diff --git a/.gitignore b/.gitignore index 41f36b84..b2bc0daa 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,4 @@ lightning_logs generated MNIST cifar-10* +src/test.py diff --git a/src/synthcity/plugins/core/models/mlp.py b/src/synthcity/plugins/core/models/mlp.py index eb599874..5ab63464 100644 --- a/src/synthcity/plugins/core/models/mlp.py +++ b/src/synthcity/plugins/core/models/mlp.py @@ -1,11 +1,11 @@ # stdlib -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union # third party import numpy as np import torch from pydantic import validate_arguments -from torch import nn +from torch import Tensor, nn from torch.utils.data import DataLoader, TensorDataset # synthcity absolute @@ -31,8 +31,27 @@ def forward(self, logits: torch.Tensor) -> torch.Tensor: ) -def get_nonlin(name: str) -> nn.Module: - if name == "none": +class GLU(nn.Module): + """Gated Linear Unit (GLU).""" + + def __init__(self, activation: Union[str, nn.Module] = "sigmoid") -> None: + super().__init__() + if type(activation) == str: + self.non_lin = get_nonlin(activation) + else: + self.non_lin = activation + + def forward(self, x: Tensor) -> Tensor: + if x.shape[-1] % 2: + raise ValueError("The last dimension of the input tensor must be even.") + a, b = x.chunk(2, dim=-1) + return a * self.non_lin(b) + + +def get_nonlin(name: Union[str, nn.Module]) -> nn.Module: + if isinstance(name, nn.Module): + return name + elif name == "none": return nn.Identity() elif name == "elu": return nn.ELU() @@ -48,6 +67,16 @@ def get_nonlin(name: 
str) -> nn.Module: return nn.Sigmoid() elif name == "softmax": return GumbelSoftmax() + elif name == "gelu": + return nn.GELU() + elif name == "glu": + return GLU() + elif name == "reglu": + return GLU("relu") + elif name == "geglu": + return GLU("gelu") + elif name in ("silu", "swish"): + return nn.SiLU() else: raise ValueError(f"Unknown nonlinearity {name}") diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 35910001..c762f389 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -1,7 +1,7 @@ # stdlib from collections.abc import Iterator from copy import deepcopy -from typing import Any, Dict, Optional +from typing import Any, Optional, Sequence # third party import numpy as np @@ -11,6 +11,9 @@ from torch import nn # synthcity absolute +from synthcity.logger import info +from synthcity.metrics.weighted_metrics import WeightedMetrics +from synthcity.utils.callbacks import Callback from synthcity.utils.constants import DEVICE from synthcity.utils.dataframe import discrete_columns @@ -28,20 +31,21 @@ def __init__( weight_decay: float = 0.0001, batch_size: int = 1024, num_timesteps: int = 1000, + is_classification: bool = False, gaussian_loss_type: str = "mse", scheduler: str = "cosine", + callbacks: Sequence[Callback] = (), device: torch.device = DEVICE, - verbose: int = 0, log_interval: int = 10, print_interval: int = 100, # model params model_type: str = "mlp", - rtdl_params: Optional[Dict[str, Any]] = None, - dim_label_emb: int = 128, + mlp_params: Optional[dict] = None, + dim_embed: int = 128, # early stopping n_iter_min: int = 100, - n_iter_print: int = 50, patience: int = 5, + patience_metric: Optional[WeightedMetrics] = None, ) -> None: super().__init__() self.__dict__.update(locals()) @@ -72,10 +76,12 @@ def _update_ema( def fit( self, X: pd.DataFrame, cond: Optional[pd.Series] = None, 
**kwargs: Any ) -> "TabDDPM": - if cond is not None: - n_labels = cond.nunique() + if self.is_classification and cond is not None: + if np.ndim(cond) != 1: + raise ValueError("cond must be a 1D array") + self.n_classes = cond.nunique() else: - n_labels = 0 + self.n_classes = 0 cat_cols = discrete_columns(X, return_counts=True) @@ -92,10 +98,10 @@ def fit( self._col_perm = np.arange(X.shape[1]) model_params = dict( - num_classes=n_labels, - is_y_cond=cond is not None, - rtdl_params=self.rtdl_params, - dim_t=self.dim_label_emb, + num_classes=self.n_classes, + use_label=cond is not None, + mlp_params=self.mlp_params, + dim_emb=self.dim_embed, ) tensors = [ @@ -104,6 +110,7 @@ def fit( if cond is None else torch.tensor(cond.values, dtype=torch.long, device=self.device), ] + self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( @@ -115,7 +122,6 @@ def fit( num_timesteps=self.num_timesteps, scheduler=self.scheduler, device=self.device, - verbose=self.verbose, ).to(self.device) self.ema_model = deepcopy(self.diffusion.denoise_fn) @@ -126,11 +132,10 @@ def fit( self.diffusion.parameters(), lr=self.lr, weight_decay=self.weight_decay ) - self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) + for cbk in self.callbacks: + cbk.on_fit_begin(self) - # if self.verbose: - # print("Starting training") - # print(self) + self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) steps = 0 curr_loss_multi = 0.0 @@ -138,8 +143,11 @@ def fit( curr_count = 0 for epoch in range(self.n_iter): + self.epoch = epoch + 1 self.diffusion.train() + [cbk.on_epoch_begin(self, epoch) for cbk in self.callbacks] + for x, y in self.dataloader: self.optimizer.zero_grad() loss_multi, loss_gauss = self.diffusion.mixed_loss(x, y) @@ -157,8 +165,8 @@ def fit( if steps % self.log_interval == 0: mloss = np.around(curr_loss_multi / curr_count, 4) gloss = np.around(curr_loss_gauss / curr_count, 4) - if 
self.verbose and steps % self.print_interval == 0: - print( + if steps % self.print_interval == 0: + info( f"Step {steps}: MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}" ) self.loss_history.loc[len(self.loss_history)] = [ @@ -175,6 +183,17 @@ def fit( self.ema_model.parameters(), self.diffusion.parameters() ) + self.eval() + + try: + [cbk.on_epoch_end(self, epoch) for cbk in self.callbacks] + except StopIteration: + info(f"Early stopped at epoch {epoch}") + break + + for cbk in self.callbacks: + cbk.on_fit_end(self) + return self def generate(self, count: int, cond: Any = None) -> np.ndarray: diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index ad1a776f..25bc57f1 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -6,18 +6,20 @@ """ # stdlib import math -from typing import Any, Callable, Optional, Tuple +from typing import Any, Optional, Tuple # third party import numpy as np import torch import torch.nn.functional as F -from torch import Tensor, nn +from torch import Tensor + +# synthcity absolute +from synthcity.logger import debug, info, warning # synthcity relative from .modules import MLPDiffusion, ResNetDiffusion from .utils import ( - FoundNANsError, discretized_gaussian_log_likelihood, extract, index_to_log_onehot, @@ -32,16 +34,7 @@ ) -def get_named_beta_schedule( - schedule_name: str, num_diffusion_timesteps: int -) -> np.ndarray: - """ - Get a pre-defined beta schedule for the given name. - The beta schedule library consists of beta schedules which remain similar - in the limit of num_diffusion_timesteps. - Beta schedules may be added, but should not be removed or changed once - they are committed to maintain backwards compatibility. 
- """ +def get_beta_schedule(schedule_name: str, num_diffusion_timesteps: int) -> np.ndarray: if schedule_name == "linear": # Linear schedule from Ho et al, extended to work for any number of # diffusion steps. @@ -52,35 +45,25 @@ def get_named_beta_schedule( beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 ) elif schedule_name == "cosine": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) + # Create a beta schedule that discretizes the given alpha_t_bar function, + # which defines the cumulative product of (1-beta) over time from t = [0,1]. + def alpha_bar(t: float) -> float: + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + # a lambda that takes an argument t between 0 and 1 and produces the cumulative + # product of (1-beta) up to that part of the diffusion process. + max_beta = 0.999 + # the maximum beta to use; use values lower than 1 to prevent singularities. + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) else: raise NotImplementedError(f"unknown beta schedule: {schedule_name}") -def betas_for_alpha_bar( - num_diffusion_timesteps: int, alpha_bar: Callable, max_beta: float = 0.999 -) -> np.ndarray: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. 
- """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( self, @@ -95,21 +78,19 @@ def __init__( parametrization: str = "x0", scheduler: str = "cosine", device: torch.device = torch.device("cpu"), - verbose: int = 0, ) -> None: super(GaussianMultinomialDiffusion, self).__init__() - if not (multinomial_loss_type in ("vb_stochastic", "vb_all")): - raise AssertionError - if not (parametrization in ("x0", "direct")): - raise AssertionError - - if verbose: - self.print = print - else: - self.print = lambda *args, **kwargs: None + if multinomial_loss_type not in ("vb_stochastic", "vb_all"): + raise ValueError( + "multinomial_loss_type must be 'vb_stochastic' or 'vb_all'" + ) + if gaussian_loss_type not in ("mse", "kl"): + raise ValueError("gaussian_loss_type must be 'mse' or 'kl'") + if parametrization not in ("x0", "direct"): + raise ValueError("parametrization must be 'x0' or 'direct'") if multinomial_loss_type == "vb_all": - self.print( + warning( "Computing the loss using the bound on _all_ timesteps." " This is expensive both in terms of memory and computation." 
) @@ -131,13 +112,15 @@ def __init__( if model_params is None: model_params = dict( - d_in=self.dim_input, num_classes=0, is_y_cond=False, rtdl_params=None + dim_in=self.dim_input, num_classes=0, use_label=False, mlp_params=None ) else: - model_params["d_in"] = self.dim_input + model_params["dim_in"] = self.dim_input - if model_params["rtdl_params"] is None: - model_params["rtdl_params"] = dict(d_layers=[256, 256, 256], dropout=0.0) + if model_params["mlp_params"] is None: + model_params["mlp_params"] = dict( + n_units_hidden=256, n_layers_hidden=3, dropout=0.0 + ) if model_type == "mlp": self.denoise_fn = MLPDiffusion(**model_params) @@ -153,7 +136,7 @@ def __init__( self.parametrization = parametrization self.scheduler = scheduler - alphas = 1.0 - get_named_beta_schedule(scheduler, num_timesteps) + alphas = 1.0 - get_beta_schedule(scheduler, num_timesteps) alphas = torch.tensor(alphas.astype("float64")) betas = 1.0 - alphas @@ -200,15 +183,18 @@ def __init__( .to(device) ) - if not (log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.0e-5): - raise AssertionError - if not ( - log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() - < 1e-5 + if ( + max( + log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item(), + log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha) + .abs() + .sum() + .item(), + (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item(), + ) + > 1e-5 ): - raise AssertionError - if not ((np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.0e-5): - raise AssertionError + raise ValueError("Numerical error in log-sum-exp") # Convert to float32 and register buffers. 
self.register_buffer("alphas", alphas.float().to(device)) @@ -257,8 +243,8 @@ def gaussian_q_sample( ) -> Tensor: if noise is None: noise = torch.randn_like(x_start) - if not (noise.shape == x_start.shape): - raise AssertionError + if noise.shape != x_start.shape: + raise ValueError("noise.shape != x_start.shape") return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise @@ -267,8 +253,8 @@ def gaussian_q_sample( def gaussian_q_posterior_mean_variance( self, x_start: Tensor, x_t: Tensor, t: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: - if not (x_start.shape == x_t.shape): - raise AssertionError + if x_start.shape != x_t.shape: + raise ValueError("x_start.shape != x_t.shape") posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t @@ -283,7 +269,7 @@ def gaussian_q_posterior_mean_variance( == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ): - raise AssertionError + raise ValueError("tensor lengths mismatch") return posterior_mean, posterior_variance, posterior_log_variance_clipped def gaussian_p_mean_variance( @@ -291,16 +277,14 @@ def gaussian_p_mean_variance( model_output: Tensor, x: Tensor, t: Tensor, - clip_denoised: bool = False, - denoised_fn: Optional[nn.Module] = None, - model_kwargs: Any = None, + model_kwargs: Optional[dict] = None, ) -> dict: if model_kwargs is None: model_kwargs = {} B, C = x.shape[:2] - if not (t.shape == (B,)): - raise AssertionError + if t.shape != (B,): + raise ValueError("length of t is not equal to batch size") model_variance = torch.cat( [ @@ -329,8 +313,8 @@ def gaussian_p_mean_variance( if not ( model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape ): - raise AssertionError( - f"{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}" + raise ValueError( + "not all of model_mean, model_log_variance, 
pred_xstart, x have the same shape" ) return { @@ -346,7 +330,6 @@ def _vb_terms_bpd( x_start: Tensor, x_t: Tensor, t: Tensor, - clip_denoised: bool = False, model_kwargs: Optional[dict] = None, ) -> dict: ( @@ -355,7 +338,7 @@ def _vb_terms_bpd( true_log_variance_clipped, ) = self.gaussian_q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) out = self.gaussian_p_mean_variance( - model_output, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + model_output, x_t, t, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] @@ -365,8 +348,8 @@ def _vb_terms_bpd( decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) - if not (decoder_nll.shape == x_start.shape): - raise AssertionError + if decoder_nll.shape != x_start.shape: + raise ValueError("decoder_nll.shape != x_start.shape") decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, @@ -404,7 +387,7 @@ def _gaussian_loss( x_t: Tensor, t: Tensor, noise: Tensor, - model_kwargs: Any = None, + model_kwargs: Optional[dict] = None, ) -> Tensor: if model_kwargs is None: model_kwargs = {} @@ -418,7 +401,6 @@ def _gaussian_loss( x_start=x_start, x_t=x_t, t=t, - clip_denoised=False, model_kwargs=model_kwargs, )["output"] @@ -427,8 +409,8 @@ def _gaussian_loss( def _predict_xstart_from_eps( self, x_t: Tensor, t: Tensor, eps: Tensor = 1e-08 ) -> Tensor: - if not (x_t.shape == eps.shape): - raise AssertionError + if x_t.shape != eps.shape: + raise ValueError("x_t.shape != eps.shape") return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps @@ -446,16 +428,12 @@ def gaussian_p_sample( model_out: Tensor, x: Tensor, t: Tensor, - clip_denoised: bool = False, - denoised_fn: Any = None, - model_kwargs: Any = None, + model_kwargs: Optional[dict] = None, ) -> dict: out = 
self.gaussian_p_mean_variance( model_out, x, t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = torch.randn_like(x) @@ -500,10 +478,14 @@ def q_pred(self, log_x_start: Tensor, t: Tensor) -> Tensor: return log_probs def predict_start(self, model_out: Tensor, log_x_t: Tensor) -> Tensor: - if not (model_out.size(0) == log_x_t.size(0)): - raise AssertionError - if not (model_out.size(1) == self.num_classes.sum()): - raise AssertionError(f"{model_out.size()}") + if model_out.size(0) != log_x_t.size(0): + raise ValueError( + f"length of model_out {model_out.size(0)} != length of log_x_t {log_x_t.size(0)}" + ) + if model_out.size(1) != self.num_classes.sum(): + raise ValueError( + f"length of model_out {model_out.size(1)} != total num_classes {self.num_classes.sum()}" + ) log_pred = torch.empty_like(model_out) for ix in self.slices_for_classes: @@ -524,8 +506,6 @@ def q_posterior(self, log_x_start: Tensor, log_x_t: Tensor, t: Tensor) -> Tensor t_broadcast == 0, log_x_start, log_EV_qxtmin_x0.to(torch.float32) ) - # unnormed_logprobs = log_EV_qxtmin_x0 + - # log q_pred_one_timestep(x_t, t) # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! # Not very easy to see why this is true. 
But it is :) unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) @@ -693,9 +673,6 @@ def mixed_loss(self, x: Tensor, cond: Optional[Tensor] = None) -> tuple: if x_num.shape[1] > 0: loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise) - # loss_multi = torch.where(out_dict['y'] == 1, loss_multi, 2 * loss_multi) - # loss_gauss = torch.where(out_dict['y'] == 1, loss_gauss, 2 * loss_gauss) - return loss_multi.mean(), loss_gauss.mean() @torch.no_grad() @@ -712,7 +689,7 @@ def mixed_elbo(self, x0: Tensor, cond: Optional[Tensor] = None) -> dict: gaussian_loss = [] xstart_mse = [] mse = [] - # mu_mse = [] + mu_mse = [] out_mean = [] true_mean = [] multinomial_loss = [] @@ -747,13 +724,12 @@ def mixed_elbo(self, x0: Tensor, cond: Optional[Tensor] = None) -> dict: x_start=x_num, x_t=x_num_t, t=t_array, - clip_denoised=False, ) multinomial_loss.append(kl) gaussian_loss.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_num) ** 2)) - # mu_mse.append(mean_flat(out["mean_mse"])) + mu_mse.append(mean_flat(out["mean_mse"])) out_mean.append(mean_flat(out["out_mean"])) true_mean.append(mean_flat(out["true_mean"])) @@ -764,7 +740,7 @@ def mixed_elbo(self, x0: Tensor, cond: Optional[Tensor] = None) -> dict: multinomial_loss = torch.stack(multinomial_loss, dim=1) xstart_mse = torch.stack(xstart_mse, dim=1) mse = torch.stack(mse, dim=1) - # mu_mse = torch.stack(mu_mse, dim=1) + mu_mse = torch.stack(mu_mse, dim=1) out_mean = torch.stack(out_mean, dim=1) true_mean = torch.stack(true_mean, dim=1) @@ -783,7 +759,7 @@ def mixed_elbo(self, x0: Tensor, cond: Optional[Tensor] = None) -> dict: "losses_multinimial": multinomial_loss, "xstart_mse": xstart_mse, "mse": mse, - # "mu_mse": mu_mse + "mu_mse": mu_mse, "out_mean": out_mean, "true_mean": true_mean, } @@ -794,16 +770,12 @@ def gaussian_ddim_step( model_out_num: Tensor, x: Tensor, t: Tensor, - clip_denoised: bool = False, - denoised_fn: Any = None, eta: float = 0.0, ) -> 
Tensor: out = self.gaussian_p_mean_variance( model_out_num, x, t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, model_kwargs=None, ) @@ -811,7 +783,7 @@ def gaussian_ddim_step( alpha_bar = extract(self.alphas_cumprod, t, x.shape) alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape) - sigma = ( + sigma = eta or ( eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev) @@ -829,39 +801,14 @@ def gaussian_ddim_step( return sample - # @torch.no_grad() - # def gaussian_ddim_sample(self, noise, T, cond=None, eta=0.0): - # x = noise - # b = x.shape[0] - # device = x.device - # for t in reversed(range(T)): - # self.print(f"Sample timestep {t:4d}", end="\r") - # t_array = (torch.ones(b, device=device) * t).long() - # out_num = self.denoise_fn(x, t_array, y=cond) - # x = self.gaussian_ddim_step(out_num, x, t_array) - # self.print() - # return x - @torch.no_grad() def gaussian_ddim_reverse_step( self, model_out_num: Tensor, x: Tensor, t: Tensor, - clip_denoised: bool = False, - denoised_fn: Any = None, - eta: float = 0.0, ) -> Tensor: - if not (eta == 0.0): - raise AssertionError("Eta must be zero.") - out = self.gaussian_p_mean_variance( - model_out_num, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=None, - ) + out = self.gaussian_p_mean_variance(model_out_num, x, t) eps = ( extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] @@ -875,18 +822,6 @@ def gaussian_ddim_reverse_step( return mean_pred - # @torch.no_grad() - # def gaussian_ddim_reverse_sample(self, x, T, cond=None): - # b = x.shape[0] - # device = x.device - # for t in range(T): - # self.print(f"Reverse timestep {t:4d}", end="\r") - # t_array = (torch.ones(b, device=device) * t).long() - # out_num = self.denoise_fn(x, t_array, y=cond) - # x = self.gaussian_ddim_reverse_step(out_num, x, t_array, eta=0.0) - # self.print() - # return x - @torch.no_grad() def multinomial_ddim_step( self, 
model_out_cat: Tensor, log_x_t: Tensor, t: Tensor, eta: float = 0.0 @@ -895,7 +830,7 @@ def multinomial_ddim_step( alpha_bar = extract(self.alphas_cumprod, t, log_x_t.shape) alpha_bar_prev = extract(self.alphas_cumprod_prev, t, log_x_t.shape) - sigma = ( + sigma = eta or ( eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev) @@ -935,20 +870,17 @@ def sample_ddim(self, num_samples: int, cond: Any = None) -> Tensor: log_z = self.log_sample_categorical(uniform_logits) for i in reversed(range(0, self.num_timesteps)): - self.print(f"Sample timestep {i:4d}", end="\r") + debug(f"Sample timestep {i:4d}", end="\r") t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), t, y=cond ) model_out_num = model_out[:, : self.num_numerics] model_out_cat = model_out[:, self.num_numerics :] - z_norm = self.gaussian_ddim_step( - model_out_num, z_norm, t, clip_denoised=False - ) + z_norm = self.gaussian_ddim_step(model_out_num, z_norm, t) if has_cat: log_z = self.multinomial_ddim_step(model_out_cat, log_z, t) - self.print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: @@ -970,27 +902,18 @@ def sample(self, num_samples: int, cond: Any = None) -> Tensor: ) log_z = self.log_sample_categorical(uniform_logits) - # y = torch.multinomial( - # cond, - # num_samples=b, - # replacement=True - # ) - # out_dict = {'y': y.long().to(device)} for i in reversed(range(0, self.num_timesteps)): - self.print(f"Sample timestep {i:4d}", end="\r") + debug(f"Sample timestep {i:4d}", end="\r") t = torch.full((b,), i, device=device, dtype=torch.long) model_out = self.denoise_fn( torch.cat([z_norm, log_z], dim=1).float(), t, y=cond ) model_out_num = model_out[:, : self.num_numerics] model_out_cat = model_out[:, self.num_numerics :] - z_norm = self.gaussian_p_sample( - model_out_num, z_norm, t, clip_denoised=False - )["sample"] + z_norm = self.gaussian_p_sample(model_out_num, 
z_norm, t)["sample"] if has_cat: log_z = self.p_sample(model_out_cat, log_z, t=t) - self.print() z_ohe = torch.exp(log_z).round() z_cat = log_z if has_cat: @@ -1006,7 +929,7 @@ def sample_all( ddim: bool = False, ) -> Tensor: if ddim: - self.print("Sample using DDIM.") + info("Sample using DDIM.") sample_fn = self.sample_ddim else: sample_fn = self.sample @@ -1017,7 +940,7 @@ def sample_all( for b in bs: sample = sample_fn(b, cond) if torch.any(sample.isnan()).item(): - raise FoundNANsError + raise ValueError("found NaNs in sample") all_samples.append(sample) return torch.cat(all_samples, dim=0) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 8caad49f..297c01bf 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -3,513 +3,115 @@ """ # stdlib import math -from typing import Any, Callable, List, Optional, Union +from typing import Optional, Union # third party import torch -import torch.nn as nn -import torch.nn.functional as F import torch.optim -from torch import Tensor +from torch import Tensor, nn -ModuleType = Union[str, Callable[..., nn.Module]] +# synthcity absolute +from synthcity.plugins.core.models.mlp import MLP, get_nonlin -class SiLU(nn.Module): - def forward(self, x: Tensor) -> Tensor: - return x * torch.sigmoid(x) - - -def timestep_embedding(timesteps: Tensor, dim: int, max_period: int = 10000) -> Tensor: - """ - Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. 
- """ - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - -def reglu(x: Tensor) -> Tensor: - """The ReGLU activation function from [1]. - References: - [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 - """ - if not (x.shape[-1] % 2 == 0): - raise AssertionError - a, b = x.chunk(2, dim=-1) - return a * F.relu(b) - - -def geglu(x: Tensor) -> Tensor: - """The GEGLU activation function from [1]. - References: - [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 - """ - if not (x.shape[-1] % 2 == 0): - raise AssertionError - a, b = x.chunk(2, dim=-1) - return a * F.gelu(b) - - -class ReGLU(nn.Module): - """The ReGLU activation function from [shazeer2020glu]. - - Examples: - .. testcode:: - - module = ReGLU() - x = torch.randn(3, 4) - assert module(x).shape == (3, 2) - - References: - * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 - """ - - def forward(self, x: Tensor) -> Tensor: - return reglu(x) - - -class GEGLU(nn.Module): - """The GEGLU activation function from [shazeer2020glu]. - - Examples: - .. 
testcode:: - - module = GEGLU() - x = torch.randn(3, 4) - assert module(x).shape == (3, 2) - - References: - * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 - """ - - def forward(self, x: Tensor) -> Tensor: - return geglu(x) - - -def _make_nn_module(module_type: ModuleType, *args: Any) -> nn.Module: - return ( - ( - ReGLU() - if module_type == "ReGLU" - else GEGLU() - if module_type == "GEGLU" - else getattr(nn, module_type)(*args) - ) - if isinstance(module_type, str) - else module_type(*args) - ) - - -class MLP(nn.Module): - """The MLP model used in [gorishniy2021revisiting]. - - The following scheme describes the architecture: - - .. code-block:: text - - MLP: (in) -> Block -> ... -> Block -> Linear -> (out) - Block: (in) -> Linear -> Activation -> Dropout -> (out) - - Examples: - .. testcode:: - - x = torch.randn(4, 2) - module = MLP.make_baseline(x.shape[1], [3, 5], 0.1, 1) - assert module(x).shape == (len(x), 1) - - References: - * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 - """ - - class Block(nn.Module): - """The main building block of `MLP`.""" - - def __init__( - self, - *, - d_in: int, - d_out: int, - bias: bool, - activation: ModuleType, - dropout: float, - ) -> None: - super().__init__() - self.linear = nn.Linear(d_in, d_out, bias) - self.activation = _make_nn_module(activation) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: Tensor) -> Tensor: - return self.dropout(self.activation(self.linear(x))) - +class TimeStepEmbedding(nn.Module): def __init__( self, - *, - d_in: int, - d_layers: List[int], - dropouts: Union[float, List[float]], - activation: Union[str, Callable[[], nn.Module]], - d_out: int, + dim: int, + max_period: int = 10000, + n_layers: int = 2, + nonlin: Union[str, nn.Module] = "silu", ) -> None: """ - Note: - `make_baseline` is the recommended constructor. 
- """ - super().__init__() - if isinstance(dropouts, float): - dropouts = [dropouts] * len(d_layers) - if not (len(d_layers) == len(dropouts)): - raise AssertionError - if activation in ["ReGLU", "GEGLU"]: - raise AssertionError - - self.blocks = nn.ModuleList( - [ - MLP.Block( - d_in=d_layers[i - 1] if i else d_in, - d_out=d, - bias=True, - activation=activation, - dropout=dropout, - ) - for i, (d, dropout) in enumerate(zip(d_layers, dropouts)) - ] - ) - self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out) - - @classmethod - def make_baseline( - cls, - d_in: int, - d_layers: List[int], - dropout: float, - d_out: int, - ) -> "MLP": - """Create a "baseline" `MLP`. - - This variation of MLP was used in [gorishniy2021revisiting]. Features: - - * :code:`Activation` = :code:`ReLU` - * all linear layers except for the first one and the last one are of the same dimension - * the dropout rate is the same for all dropout layers + Create sinusoidal timestep embeddings. Args: - d_in: the input size - d_layers: the dimensions of the linear layers. If there are more than two - layers, then all of them except for the first and the last ones must - have the same dimension. Valid examples: :code:`[]`, :code:`[8]`, - :code:`[8, 16]`, :code:`[2, 2, 2, 2]`, :code:`[1, 2, 2, 4]`. Invalid - example: :code:`[1, 2, 3, 4]`. - dropout: the dropout rate for all hidden layers - d_out: the output size - Returns: - MLP - - References: - * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 + - dim (int): the dimension of the output. + - max_period (int): controls the minimum frequency of the embeddings. 
+ - n_layers (int): number of dense layers """ - if not (isinstance(dropout, float)): - raise AssertionError - if len(d_layers) > 2: - if not len(set(d_layers[1:-1])) == 1: - raise AssertionError( - "if d_layers contains more than two elements, then" - " all elements except for the first and the last ones must be equal." - ) - return MLP( - d_in=d_in, - d_layers=d_layers, - dropouts=dropout, - activation="ReLU", - d_out=d_out, - ) - - def forward(self, x: Tensor) -> Tensor: - x = x.float() - for block in self.blocks: - x = block(x) - x = self.head(x) - return x - - -class ResNet(nn.Module): - """The ResNet model used in [gorishniy2021revisiting]. - The following scheme describes the architecture: - .. code-block:: text - ResNet: (in) -> Linear -> Block -> ... -> Block -> Head -> (out) - |-> Norm -> Linear -> Activation -> Dropout -> Linear -> Dropout ->| - | | - Block: (in) ------------------------------------------------------------> Add -> (out) - Head: (in) -> Norm -> Activation -> Linear -> (out) - Examples: - .. 
testcode:: - x = torch.randn(4, 2) - module = ResNet.make_baseline( - d_in=x.shape[1], - n_blocks=2, - d_main=3, - d_hidden=4, - dropout_first=0.25, - dropout_second=0.0, - d_out=1 - ) - assert module(x).shape == (len(x), 1) - References: - * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 - """ - - class Block(nn.Module): - """The main building block of `ResNet`.""" - - def __init__( - self, - *, - d_main: int, - d_hidden: int, - bias_first: bool, - bias_second: bool, - dropout_first: float, - dropout_second: float, - normalization: ModuleType, - activation: ModuleType, - skip_connection: bool, - ) -> None: - super().__init__() - self.normalization = _make_nn_module(normalization, d_main) - self.linear_first = nn.Linear(d_main, d_hidden, bias_first) - self.activation = _make_nn_module(activation) - self.dropout_first = nn.Dropout(dropout_first) - self.linear_second = nn.Linear(d_hidden, d_main, bias_second) - self.dropout_second = nn.Dropout(dropout_second) - self.skip_connection = skip_connection - - def forward(self, x: Tensor) -> Tensor: - x_input = x - x = self.normalization(x) - x = self.linear_first(x) - x = self.activation(x) - x = self.dropout_first(x) - x = self.linear_second(x) - x = self.dropout_second(x) - if self.skip_connection: - x = x_input + x - return x + super().__init__() + self.dim = dim + self.max_period = max_period + self.n_layers = n_layers - class Head(nn.Module): - """The final module of `ResNet`.""" + if dim % 2 != 0: + raise ValueError(f"embedding dim must be even, got {dim}") - def __init__( - self, - *, - d_in: int, - d_out: int, - bias: bool, - normalization: ModuleType, - activation: ModuleType, - ) -> None: - super().__init__() - self.normalization = _make_nn_module(normalization, d_in) - self.activation = _make_nn_module(activation) - self.linear = nn.Linear(d_in, d_out, bias) + layers = [] + for _ in range(n_layers - 1): + 
layers.append(nn.Linear(dim, dim)) + layers.append(get_nonlin(nonlin)) - def forward(self, x: Tensor) -> Tensor: - if self.normalization is not None: - x = self.normalization(x) - x = self.activation(x) - x = self.linear(x) - return x + self.fc = nn.Sequential(*layers, nn.Linear(dim, dim)) - def __init__( - self, - *, - d_in: int, - n_blocks: int, - d_main: Optional[int], - d_hidden: int, - dropout_first: float, - dropout_second: float, - normalization: ModuleType, - activation: ModuleType, - d_out: int, - ) -> None: - """ - Note: - `make_baseline` is the recommended constructor. + def forward(self, timesteps: Tensor) -> Tensor: """ - super().__init__() - - self.first_layer = nn.Linear(d_in, d_main) - if d_main is None: - d_main = d_in - self.blocks = nn.Sequential( - *[ - ResNet.Block( - d_main=d_main, - d_hidden=d_hidden, - bias_first=True, - bias_second=True, - dropout_first=dropout_first, - dropout_second=dropout_second, - normalization=normalization, - activation=activation, - skip_connection=True, - ) - for _ in range(n_blocks) - ] - ) - self.head = ResNet.Head( - d_in=d_main, - d_out=d_out, - bias=True, - normalization=normalization, - activation=activation, - ) - - @classmethod - def make_baseline( - cls, - *, - d_in: int, - n_blocks: int, - d_main: int, - d_hidden: int, - dropout_first: float, - dropout_second: float, - d_out: int, - ) -> "ResNet": - """Create a "baseline" `ResNet`. - This variation of ResNet was used in [gorishniy2021revisiting]. Features: - * :code:`Activation` = :code:`ReLU` - * :code:`Norm` = :code:`BatchNorm1d` Args: - d_in: the input size - n_blocks: the number of Blocks - d_main: the input size (or, equivalently, the output size) of each Block - d_hidden: the output size of the first linear layer in each Block - dropout_first: the dropout rate of the first dropout layer in each Block. - dropout_second: the dropout rate of the second dropout layer in each Block. 
- References: - * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 + - timesteps (Tensor): 1D Tensor of N indices, one per batch element. """ - return cls( - d_in=d_in, - n_blocks=n_blocks, - d_main=d_main, - d_hidden=d_hidden, - dropout_first=dropout_first, - dropout_second=dropout_second, - normalization="BatchNorm1d", - activation="ReLU", - d_out=d_out, - ) - - def forward(self, x: Tensor) -> Tensor: - x = x.float() - x = self.first_layer(x) - x = self.blocks(x) - x = self.head(x) - return x - - -# **For diffusion** + d, T = self.dim, self.max_period + mid = d // 2 + fs = torch.exp(-math.log(T) / mid * torch.arange(mid, dtype=torch.float32)) + args = timesteps[:, None].float() * fs[None] + emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + return self.fc(emb) class MLPDiffusion(nn.Module): + add_residual = False + def __init__( self, - d_in: int, - num_classes: int, - is_y_cond: bool, - rtdl_params: dict, - dim_t: int = 128, + dim_in: int, + dim_emb: int = 128, + *, + mlp_params: dict = {}, + use_label: bool = False, + num_classes: int = 0, + emb_nonlin: Union[str, nn.Module] = "silu", + max_time_period: int = 10000, ) -> None: super().__init__() - self.dim_t = dim_t + self.dim_t = dim_emb self.num_classes = num_classes - self.is_y_cond = is_y_cond - - # d0 = rtdl_params['d_layers'][0] - - rtdl_params["d_in"] = dim_t - rtdl_params["d_out"] = d_in + self.has_label = use_label - self.mlp = MLP.make_baseline(**rtdl_params) + if isinstance(emb_nonlin, str): + self.emb_nonlin = get_nonlin(emb_nonlin) + else: + self.emb_nonlin = emb_nonlin - if self.num_classes > 0 and is_y_cond: - self.label_emb = nn.Embedding(self.num_classes, dim_t) - elif self.num_classes == 0 and is_y_cond: - self.label_emb = nn.Linear(1, dim_t) + self.proj = nn.Linear(dim_in, dim_emb) + self.time_emb = TimeStepEmbedding(dim_emb, max_time_period) - self.proj = nn.Linear(d_in, dim_t) - 
self.time_embed = nn.Sequential( - nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t) + if use_label: + if self.num_classes > 0: + self.label_emb = nn.Embedding(self.num_classes, dim_emb) + elif self.num_classes == 0: # regression + self.label_emb = nn.Linear(1, dim_emb) + + self.model = MLP( + n_units_in=dim_emb, + n_units_out=dim_in, + task_type="/", + residual=self.add_residual, + **mlp_params, ) - def forward( - self, x: Tensor, timesteps: Tensor, y: Optional[Tensor] = None - ) -> Tensor: - emb = self.time_embed(timestep_embedding(timesteps, self.dim_t)) - if self.is_y_cond and y is not None: - if self.num_classes > 0: - y = y.squeeze() + def forward(self, x: Tensor, t: Tensor, y: Optional[Tensor] = None) -> Tensor: + emb = self.time_emb(t) + if self.has_label: + if y is None: + raise ValueError("y must be provided if use_label is True") + if self.num_classes == 0: + y = y.resize(-1, 1).float() else: - y = y.resize(y.size(0), 1).float() - emb += F.silu(self.label_emb(y)) + y = y.squeeze().long() + emb += self.emb_nonlin(self.label_emb(y)) x = self.proj(x) + emb - return self.mlp(x) + return self.model(x) -class ResNetDiffusion(nn.Module): - def __init__( - self, - d_in: int, - num_classes: int, - is_y_cond: bool, - rtdl_params: dict, - dim_t: int = 256, - ) -> None: - super().__init__() - self.dim_t = dim_t - self.num_classes = num_classes - - rtdl_params["d_in"] = d_in - rtdl_params["d_out"] = d_in - rtdl_params["emb_d"] = dim_t - self.resnet = ResNet.make_baseline(**rtdl_params) - - if self.num_classes > 0 and is_y_cond: - self.label_emb = nn.Embedding(self.num_classes, dim_t) - elif self.num_classes == 0 and is_y_cond: - self.label_emb = nn.Linear(1, dim_t) - - self.proj = nn.Linear(d_in, dim_t) - self.time_embed = nn.Sequential( - nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t) - ) - - def forward( - self, x: Tensor, timesteps: Tensor, y: Optional[Tensor] = None - ) -> Tensor: - emb = self.time_embed(timestep_embedding(timesteps, 
self.dim_t)) - if self.is_y_cond and y is not None: - if self.num_classes > 0: - y = y.squeeze() - else: - y = y.resize(y.size(0), 1).float() - emb += F.silu(self.label_emb(y)) - x = self.proj(x) + emb - return self.resnet(x) +class ResNetDiffusion(MLPDiffusion): + add_residual = True diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 4d4c92bd..8ec0c025 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -18,13 +18,12 @@ def normal_kl(mean1: Tensor, logvar1: Tensor, mean2: Tensor, logvar2: Tensor) -> Shapes are automatically broadcasted, so batches can be compared to scalars, among other use cases. """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, Tensor): - tensor = obj - break - if tensor is None: - raise AssertionError("at least one argument must be a Tensor") + try: + tensor = next( + x for x in (mean1, logvar1, mean2, logvar2) if isinstance(x, Tensor) + ) + except StopIteration: + raise TypeError("at least one argument must be a Tensor") # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). @@ -66,7 +65,7 @@ def discretized_gaussian_log_likelihood( :return: a tensor like x of log probabilities (in nats). 
""" if not (x.shape == means.shape == log_scales.shape): - raise AssertionError + raise ValueError("shapes must match") centered_x = x - means inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1.0 / 255.0) @@ -83,8 +82,8 @@ def discretized_gaussian_log_likelihood( x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)) ), ) - if not (log_probs.shape == x.shape): - raise AssertionError + if log_probs.shape != x.shape: + raise ValueError("shapes must match") return log_probs @@ -123,8 +122,9 @@ def log_1_min_a(a: Tensor) -> Tensor: def log_add_exp(a: Tensor, b: Tensor) -> Tensor: - maximum = torch.max(a, b) - return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) + """Numerically stable log(exp(a) + exp(b)).""" + m = torch.max(a, b) + return m + torch.log(torch.exp(a - m) + torch.exp(b - m)) def extract(a: Tensor, t: Tensor, x_shape: tuple) -> Tensor: @@ -171,13 +171,6 @@ def sliced_logsumexp(x: Tensor, slices: Tensor) -> Tensor: return slice_lse_repeated -class FoundNANsError(BaseException): - """Found NANs during sampling""" - - def __init__(self, message: str = "Found NANs during sampling.") -> None: - super(FoundNANsError, self).__init__(message) - - class TensorDataLoader: """ A DataLoader-like object for a set of tensors that can be much faster than @@ -198,13 +191,13 @@ def __init__( :returns: A FastTensorDataLoader. 
""" if not all(t.shape[0] == tensors[0].shape[0] for t in tensors): - raise AssertionError + raise ValueError("All tensors must have the same length.") self.tensors = tensors self.dataset_len = self.tensors[0].shape[0] self.batch_size = batch_size self.shuffle = shuffle - def __iter__(self) -> Iterator[tuple]: + def __iter__(self) -> Iterator: idx = np.arange(self.dataset_len) if self.shuffle: np.random.shuffle(idx) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 6556ec05..b149e336 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -4,7 +4,7 @@ # stdlib from pathlib import Path -from typing import Any, List +from typing import Any, List, Sequence # third party import numpy as np @@ -19,6 +19,7 @@ from synthcity.plugins.core.models.tabular_ddpm import TabDDPM from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema +from synthcity.utils.callbacks import Callback from synthcity.utils.constants import DEVICE @@ -31,14 +32,55 @@ class TabDDPMPlugin(Plugin): Tabular denoising diffusion probabilistic model. Args: - ... + is_classification: bool = False + Whether the task is classification or regression. + n_iter: int = 1000 + Number of epochs for training. + lr: float = 0.002 + Learning rate. + weight_decay: float = 1e-4 + L2 weight decay. + batch_size: int = 1024 + Size of mini-batches. + model_type: str = "mlp" + Type of model to use. Either "mlp" or "resnet". + num_timesteps: int = 1000 + Number of timesteps to use in the diffusion process. + gaussian_loss_type: str = "mse" + Type of loss to use for the Gaussian diffusion process. Either "mse" or "kl". + scheduler: str = "cosine" + The scheduler of forward process variance 'beta' to use. Either "cosine" or "linear". + device: Any = DEVICE + Device to use for training. + callbacks: Sequence[Callback] = () + Callbacks to use during training. 
+ log_interval: int = 100 + Number of iterations between logging. + print_interval: int = 500 + Number of iterations between printing. + n_layers_hidden: int = 3 + Number of hidden layers in the MLP. + dim_hidden: int = 256 + Number of hidden units per hidden layer in the MLP. + dropout: float = 0.0 + Dropout rate. + dim_embed: int = 128 + Dimensionality of the embedding space. + random_state: int + random seed to use + workspace: Path. + Optional Path for caching intermediary results. + compress_dataset: bool. Default = False. + Drop redundant features before training the generator. + sampling_patience: int. + Max inference iterations to wait for the generated data to match the training schema. Example: >>> from sklearn.datasets import load_iris >>> from synthcity.plugins import Plugins - >>> X, y = load_iris(as_frame = True, return_X_y = True) + >>> X, y = load_iris(as_frame=True, return_X_y=True) >>> X["target"] = y - >>> plugin = Plugins().get("ddpm", n_iter = 100) + >>> plugin = Plugins().get("ddpm", n_iter=100, is_classification=True) >>> plugin.fit(X) >>> plugin.generate(50) @@ -58,19 +100,14 @@ def __init__( gaussian_loss_type: str = "mse", scheduler: str = "cosine", device: Any = DEVICE, - verbose: int = 0, + callbacks: Sequence[Callback] = (), log_interval: int = 100, print_interval: int = 500, # model params - num_layers: int = 3, + n_layers_hidden: int = 3, dim_hidden: int = 256, dropout: float = 0.0, - dim_label_emb: int = 128, - # early stopping - n_iter_min: int = 100, - n_iter_print: int = 50, - patience: int = 5, - # patience_metric: Optional[WeightedMetrics] = None, + dim_embed: int = 128, # core plugin arguments random_state: int = 0, workspace: Path = Path("workspace"), @@ -89,7 +126,10 @@ def __init__( self.is_classification = is_classification - rtdl_params = dict(d_layers=[dim_hidden] * num_layers, dropout=dropout) + mlp_params = dict( + n_layers_hidden=n_layers_hidden, n_units_hidden=dim_hidden, dropout=dropout + ) + self.model = TabDDPM( 
n_iter=n_iter, lr=lr, @@ -97,17 +137,15 @@ def __init__( batch_size=batch_size, num_timesteps=num_timesteps, gaussian_loss_type=gaussian_loss_type, + is_classification=is_classification, scheduler=scheduler, device=device, - verbose=verbose, + callbacks=callbacks, log_interval=log_interval, print_interval=print_interval, model_type=model_type, - rtdl_params=rtdl_params, - dim_label_emb=dim_label_emb, - n_iter_min=n_iter_min, - n_iter_print=n_iter_print, - patience=patience, + mlp_params=mlp_params, + dim_embed=dim_embed, ) @staticmethod @@ -141,20 +179,24 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: CategoricalDistribution(name="batch_size", choices=[256, 4096]), CategoricalDistribution(name="num_timesteps", choices=[100, 1000]), CategoricalDistribution(name="n_iter", choices=[5000, 10000, 20000]), - CategoricalDistribution(name="num_layers", choices=[2, 4, 6, 8]), + CategoricalDistribution(name="n_layers_hidden", choices=[2, 4, 6, 8]), CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]), ] def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": - cond = None + """Fit the model to the data. + + Optionally, a condition can be given as the keyword argument `cond`. + + If the task is classification, the target labels are automatically regarded as the condition, and no additional condition should be given. + + If the task is regression, the target variable is not specially treated. There is no condition by default, but can be given by the user, either as a column name or an array-like. 
+ """ + df = X.dataframe() + cond = kwargs.pop("cond", None) + if args: - if len(args) > 1: - raise ValueError("Only one positional argument is allowed") - if "cond" in kwargs: - raise ValueError("cond is already given by the positional argument") - cond = args[0] - elif "cond" in kwargs: - cond = kwargs.pop("cond") + raise ValueError("Only keyword arguments are allowed") if self.is_classification: if cond is not None: @@ -164,15 +206,14 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": _, cond = X.unpack() self._labels, self._cond_dist = np.unique(cond, return_counts=True) self._cond_dist = self._cond_dist / self._cond_dist.sum() - - # NOTE: should we include the target column in `df`? - df = X.dataframe() + else: + if type(cond) is str: + cond = df[cond] if cond is not None: cond = pd.Series(cond, index=df.index) - # self.encoder = TabularEncoder().fit(X) - + # NOTE: cond may also be included in the dataframe self.model.fit(df, cond, **kwargs) return self @@ -184,7 +225,7 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader # randomly generate labels following the distribution of the training data cond = np.random.choice(self._labels, size=count, p=self._cond_dist) - def callback(count, cond=cond): # type: ignore + def callback(count): # type: ignore return self.model.generate(count, cond=cond) return self._safe_generate(callback, count, syn_schema, **kwargs) diff --git a/src/synthcity/utils/callbacks.py b/src/synthcity/utils/callbacks.py new file mode 100644 index 00000000..fa54074c --- /dev/null +++ b/src/synthcity/utils/callbacks.py @@ -0,0 +1,91 @@ +# stdlib +from abc import ABC, abstractmethod +from typing import Optional + +# third party +import numpy as np +import pandas as pd +from torch import Tensor, nn + +# synthcity absolute +from synthcity.metrics.weighted_metrics import WeightedMetrics + + +class Callback(ABC): + """Abstract base class of callbacks.""" + + @abstractmethod + def 
on_epoch_begin(self, model: nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def on_epoch_end(self, model: nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def on_fit_begin(self, model: nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def on_fit_end(self, model: nn.Module) -> None: + raise NotImplementedError + + +class EarlyStopping(Callback): + def __init__( + self, + patience: int = 5, + min_epochs: int = 100, + patience_metric: Optional[WeightedMetrics] = None, + ) -> None: + self.patience = patience + self.patience_metric = patience_metric + self.min_epochs = min_epochs + self.best_score = self._init_patience_score() + self.best_model_state = None + self.wait = 0 + self._epochs = 0 + + def on_epoch_end(self, model: nn.Module) -> None: + self._epochs += 1 + if self.patience_metric is not None: + if not hasattr(self, "X_val"): + self.X_val = model.X_val + if isinstance(self.X_val, Tensor): + self.X_val = self.X_val.detach().cpu().numpy() + self._evaluate_patience_metric(model) + if self.wait >= self.patience and self._epochs >= self.min_epochs: + raise StopIteration("Early stopping") + + def on_fit_end(self, model: nn.Module) -> None: + if self.best_model_state is not None: + model.load_state_dict(self.best_model_state) # type: ignore + + def _init_patience_score(self) -> float: + if self.patience_metric is None: + return 0 + elif self.patience_metric.direction() == "minimize": + return np.inf + else: + return -np.inf + + def _evaluate_patience_metric(self, model: nn.Module) -> None: + X_val = self.X_val + X_syn = model.generate(len(X_val)) + + new_score = self.patience_metric.evaluate( # type: ignore + pd.DataFrame(X_val), + pd.DataFrame(X_syn), + ) + + if self.patience_metric.direction() == "minimize": # type: ignore + is_new_best = new_score < self.best_score + else: + is_new_best = new_score > self.best_score + + if is_new_best: + self.wait = 0 + self.best_score = new_score + 
self.best_model_state = model.state_dict() + else: + self.wait += 1 diff --git a/src/synthcity/utils/dataframe.py b/src/synthcity/utils/dataframe.py index c12b29da..a313b91e 100644 --- a/src/synthcity/utils/dataframe.py +++ b/src/synthcity/utils/dataframe.py @@ -6,7 +6,7 @@ def constant_columns(dataframe: pd.DataFrame) -> list: """ Find constant value columns in a pandas dataframe. """ - return discrete_columns(dataframe, 2) + return discrete_columns(dataframe, 1) def discrete_columns( @@ -19,5 +19,5 @@ def discrete_columns( (col, cnt) if return_counts else col for col, vals in dataframe.items() for cnt in [vals.nunique()] - if cnt < max_classes + if cnt <= max_classes ] diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index 7f56077a..cddadf62 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -15,16 +15,11 @@ plugin_name = "ddpm" plugin_args = dict( n_iter=1000, - # is_classification=True, + is_classification=True, batch_size=200, num_timesteps=500, - verbose=1, log_interval=10, - print_interval=100 - # rtdl_params=dict( - # d_layers=[256, 256], - # dropout=0.0 - # ) + print_interval=100, ) From a9438dc93f01d6c226c3865f3218fd7f60dc55e5 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 16 Mar 2023 13:14:31 +0100 Subject: [PATCH 25/95] remove TensorDataLoader, update test_ddpm --- .../core/models/tabular_ddpm/__init__.py | 19 +++-- .../gaussian_multinomial_diffsuion.py | 74 +++++++++++-------- .../plugins/core/models/tabular_ddpm/utils.py | 50 ++----------- src/synthcity/plugins/generic/plugin_ddpm.py | 5 ++ tests/plugins/generic/test_ddpm.py | 47 +++++++----- 5 files changed, 94 insertions(+), 101 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index c762f389..d80c2a85 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ 
b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -9,6 +9,7 @@ import torch from pydantic import validate_arguments from torch import nn +from torch.utils.data import DataLoader, TensorDataset # synthcity absolute from synthcity.logger import info @@ -19,7 +20,6 @@ # synthcity relative from .gaussian_multinomial_diffsuion import GaussianMultinomialDiffusion -from .utils import TensorDataLoader class TabDDPM(nn.Module): @@ -104,14 +104,18 @@ def fit( dim_emb=self.dim_embed, ) - tensors = [ + dataset = TensorDataset( torch.tensor(X.values, dtype=torch.float32, device=self.device), - np.repeat(None, len(X)) + torch.tensor([torch.nan] * len(X), dtype=torch.float32, device=self.device) if cond is None - else torch.tensor(cond.values, dtype=torch.long, device=self.device), - ] + else torch.tensor( + cond.values, + dtype=torch.long if self.is_classification else torch.float32, + device=self.device, + ), + ) - self.dataloader = TensorDataLoader(*tensors, batch_size=self.batch_size) + self.dataloader = DataLoader(dataset, batch_size=self.batch_size) self.diffusion = GaussianMultinomialDiffusion( model_type=self.model_type, @@ -150,7 +154,8 @@ def fit( for x, y in self.dataloader: self.optimizer.zero_grad() - loss_multi, loss_gauss = self.diffusion.mixed_loss(x, y) + args = (x,) if cond is None else (x, y) + loss_multi, loss_gauss = self.diffusion.mixed_loss(*args) loss = loss_multi + loss_gauss loss.backward() self.optimizer.step() diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 25bc57f1..270d8b03 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -21,7 +21,6 @@ from .modules import MLPDiffusion, ResNetDiffusion from .utils import ( discretized_gaussian_log_likelihood, - extract, 
index_to_log_onehot, log_1_min_a, log_add_exp, @@ -29,6 +28,7 @@ mean_flat, normal_kl, ohe_to_categories, + perm_and_expand, sliced_logsumexp, sum_except_batch, ) @@ -233,9 +233,9 @@ def __init__( def gaussian_q_mean_variance( self, x_start: Tensor, t: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: - mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract(self.log_1_min_cumprod_alpha, t, x_start.shape) + mean = perm_and_expand(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = perm_and_expand(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = perm_and_expand(self.log_1_min_cumprod_alpha, t, x_start.shape) return mean, variance, log_variance def gaussian_q_sample( @@ -246,8 +246,9 @@ def gaussian_q_sample( if noise.shape != x_start.shape: raise ValueError("noise.shape != x_start.shape") return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + perm_and_expand(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + perm_and_expand(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise ) def gaussian_q_posterior_mean_variance( @@ -256,11 +257,11 @@ def gaussian_q_posterior_mean_variance( if x_start.shape != x_t.shape: raise ValueError("x_start.shape != x_t.shape") posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start - + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + perm_and_expand(self.posterior_mean_coef1, t, x_t.shape) * x_start + + perm_and_expand(self.posterior_mean_coef2, t, x_t.shape) * x_t ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract( + posterior_variance = perm_and_expand(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = perm_and_expand( self.posterior_log_variance_clipped, t, x_t.shape ) if not ( @@ 
-296,15 +297,15 @@ def gaussian_p_mean_variance( # model_variance = self.posterior_variance.to(x.device) model_log_variance = torch.log(model_variance) - model_variance = extract(model_variance, t, x.shape) - model_log_variance = extract(model_log_variance, t, x.shape) + model_variance = perm_and_expand(model_variance, t, x.shape) + model_log_variance = perm_and_expand(model_log_variance, t, x.shape) if self.gaussian_parametrization == "eps": pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) elif self.gaussian_parametrization == "x0": pred_xstart = model_output else: - raise NotImplementedError + raise ValueError("unknown gaussian_parametrization. Must be 'eps' or 'x0'") model_mean, _, _ = self.gaussian_q_posterior_mean_variance( x_start=pred_xstart, x_t=x, t=t @@ -412,16 +413,17 @@ def _predict_xstart_from_eps( if x_t.shape != eps.shape: raise ValueError("x_t.shape != eps.shape") return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + perm_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - perm_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_eps_from_xstart( self, x_t: Tensor, t: Tensor, pred_xstart: Tensor ) -> Tensor: return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart - ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + perm_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / perm_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def gaussian_p_sample( self, @@ -453,8 +455,8 @@ def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor) -> Tensor: return kl def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor: - log_alpha_t = extract(self.log_alpha, t, log_x_t.shape) - log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape) + log_alpha_t = perm_and_expand(self.log_alpha, t, log_x_t.shape) + 
log_1_min_alpha_t = perm_and_expand(self.log_1_min_alpha, t, log_x_t.shape) # alpha_t * E[xt] + (1 - alpha_t) 1 / K log_probs = log_add_exp( @@ -465,8 +467,10 @@ def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor: return log_probs def q_pred(self, log_x_start: Tensor, t: Tensor) -> Tensor: - log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) - log_1_min_cumprod_alpha = extract( + log_cumprod_alpha_t = perm_and_expand( + self.log_cumprod_alpha, t, log_x_start.shape + ) + log_1_min_cumprod_alpha = perm_and_expand( self.log_1_min_cumprod_alpha, t, log_x_start.shape ) @@ -525,7 +529,7 @@ def p_pred(self, model_out: Tensor, log_x: Tensor, t: Tensor) -> Tensor: elif self.parametrization == "direct": log_model_pred = self.predict_start(model_out, log_x) else: - raise ValueError + raise ValueError(f"unknown parametrization {self.parametrization}") return log_model_pred @torch.no_grad() @@ -613,8 +617,11 @@ def sample_time( pt = torch.ones_like(t).float() / self.num_timesteps return t, pt + else: - raise ValueError + raise ValueError( + "Unknown sampling method. Must be 'importance' or 'uniform'." + ) def _multinomial_loss( self, @@ -636,8 +643,11 @@ def _multinomial_loss( # Expensive, dont do it ;). # DEPRECATED return -self.nll(log_x_start) + else: - raise ValueError() + raise ValueError( + "Unknown multinomial loss type. Must be 'vb_stochastic' or 'vb_all'." 
+ ) def mixed_loss(self, x: Tensor, cond: Optional[Tensor] = None) -> tuple: b = x.shape[0] @@ -665,6 +675,7 @@ def mixed_loss(self, x: Tensor, cond: Optional[Tensor] = None) -> tuple: loss_multi = torch.zeros((1,)).float() loss_gauss = torch.zeros((1,)).float() + if x_cat.shape[1] > 0: loss_multi = self._multinomial_loss( model_out_cat, log_x_cat, log_x_cat_t, t, pt @@ -781,8 +792,8 @@ def gaussian_ddim_step( eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) - alpha_bar = extract(self.alphas_cumprod, t, x.shape) - alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape) + alpha_bar = perm_and_expand(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = perm_and_expand(self.alphas_cumprod_prev, t, x.shape) sigma = eta or ( eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) @@ -811,9 +822,10 @@ def gaussian_ddim_reverse_step( out = self.gaussian_p_mean_variance(model_out_num, x, t) eps = ( - extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] - ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) - alpha_bar_next = extract(self.alphas_cumprod_next, t, x.shape) + perm_and_expand(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / perm_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = perm_and_expand(self.alphas_cumprod_next, t, x.shape) mean_pred = ( out["pred_xstart"] * torch.sqrt(alpha_bar_next) @@ -828,8 +840,8 @@ def multinomial_ddim_step( ) -> Tensor: log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t) - alpha_bar = extract(self.alphas_cumprod, t, log_x_t.shape) - alpha_bar_prev = extract(self.alphas_cumprod_prev, t, log_x_t.shape) + alpha_bar = perm_and_expand(self.alphas_cumprod, t, log_x_t.shape) + alpha_bar_prev = perm_and_expand(self.alphas_cumprod_prev, t, log_x_t.shape) sigma = eta or ( eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py 
b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 8ec0c025..04eb9d8f 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -1,9 +1,6 @@ # future from __future__ import annotations -# stdlib -from typing import Iterator - # third party import numpy as np import torch @@ -127,10 +124,11 @@ def log_add_exp(a: Tensor, b: Tensor) -> Tensor: return m + torch.log(torch.exp(a - m) + torch.exp(b - m)) -def extract(a: Tensor, t: Tensor, x_shape: tuple) -> Tensor: - b, *_ = t.shape - t = t.to(a.device) - out = a.gather(-1, t) +def perm_and_expand(a: Tensor, t: Tensor, x_shape: tuple) -> Tensor: + """Permutes a tensor in the order specified by `t` and expands it to `x_shape`.""" + if not (a.ndim == 1 and t.shape == (x_shape[0],)): + raise ValueError(f"dimensionality mismatch: {a.shape}, {t.shape}, {x_shape}") + out = a[t] while len(out.shape) < len(x_shape): out = out[..., None] return out.expand(x_shape) @@ -169,41 +167,3 @@ def sliced_logsumexp(x: Tensor, slices: Tensor) -> Tensor: slice_lse, slice_ends - slice_starts, dim=-1 ) return slice_lse_repeated - - -class TensorDataLoader: - """ - A DataLoader-like object for a set of tensors that can be much faster than - TensorDataset + DataLoader because dataloader grabs individual indices of - the dataset and calls cat (slow). - Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 - """ - - def __init__( - self, *tensors: Tensor, batch_size: int = 32, shuffle: bool = False - ) -> None: - """ - Initialize a FastTensorDataLoader. - :param *tensors: tensors to store. Must have the same length @ dim 0. - :param batch_size: batch size to load. - :param shuffle: if True, shuffle the data *in-place* whenever an - iterator is created out of this object. - :returns: A FastTensorDataLoader. 
- """ - if not all(t.shape[0] == tensors[0].shape[0] for t in tensors): - raise ValueError("All tensors must have the same length.") - self.tensors = tensors - self.dataset_len = self.tensors[0].shape[0] - self.batch_size = batch_size - self.shuffle = shuffle - - def __iter__(self) -> Iterator: - idx = np.arange(self.dataset_len) - if self.shuffle: - np.random.shuffle(idx) - for i in range(0, self.dataset_len, self.batch_size): - s = idx[i : i + self.batch_size] - yield tuple(t[s] for t in self.tensors) - - def __len__(self) -> int: - return len(range(0, self.dataset_len, self.batch_size)) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index b149e336..631480fc 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -195,6 +195,11 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": df = X.dataframe() cond = kwargs.pop("cond", None) + # note that the TabularEncoder is not used in this plugin, because the + # Gaussian multinomial diffusion module needs to know the number of classes + # for each discrete feature before it applies torch.nn.functional.one_hot + # on these features, and it also preprocesses the continuous features differently. 
+ if args: raise ValueError("Only keyword arguments are allowed") diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index cddadf62..2f9afeae 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -1,3 +1,7 @@ +# stdlib +from itertools import product +from typing import Any, Generator + # third party import numpy as np import pandas as pd @@ -13,9 +17,8 @@ from synthcity.plugins.generic.plugin_ddpm import plugin plugin_name = "ddpm" -plugin_args = dict( +plugin_params = dict( n_iter=1000, - is_classification=True, batch_size=200, num_timesteps=500, log_interval=10, @@ -23,29 +26,39 @@ ) -@pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) -) +def extend_fixtures( + plugin_name: str = plugin_name, + plugin: Any = plugin, + plugin_params: dict = plugin_params, + **extra_params: list +) -> Generator: + if not extra_params: + yield from generate_fixtures(plugin_name, plugin, plugin_params) + return + param_set = list(product(*extra_params.values())) + for values in param_set: + params = plugin_params.copy() + params.update(zip(extra_params.keys(), values)) + yield from generate_fixtures(plugin_name, plugin, params) + + +@pytest.mark.parametrize("test_plugin", extend_fixtures()) def test_plugin_sanity(test_plugin: Plugin) -> None: assert test_plugin is not None -@pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) -) +@pytest.mark.parametrize("test_plugin", extend_fixtures()) def test_plugin_name(test_plugin: Plugin) -> None: assert test_plugin.name() == plugin_name -@pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) -) +@pytest.mark.parametrize("test_plugin", extend_fixtures()) def test_plugin_type(test_plugin: Plugin) -> None: assert test_plugin.type() == "generic" @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + 
"test_plugin", extend_fixtures(is_classification=[True, False]) ) def test_plugin_fit(test_plugin: Plugin) -> None: X = pd.DataFrame(load_iris()["data"]) @@ -53,7 +66,7 @@ def test_plugin_fit(test_plugin: Plugin) -> None: @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", extend_fixtures(is_classification=[True, False]) ) def test_plugin_generate(test_plugin: Plugin) -> None: X = pd.DataFrame(load_iris()["data"]) @@ -69,7 +82,7 @@ def test_plugin_generate(test_plugin: Plugin) -> None: @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", extend_fixtures(is_classification=[True, False]) ) def test_plugin_generate_constraints(test_plugin: Plugin) -> None: X = pd.DataFrame(load_iris()["data"]) @@ -100,9 +113,7 @@ def test_plugin_generate_constraints(test_plugin: Plugin) -> None: assert list(X_gen.columns) == list(X.columns) -@pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) -) +@pytest.mark.parametrize("test_plugin", extend_fixtures()) def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 6 @@ -123,7 +134,7 @@ def test_eval_performance_ddpm(compress_dataset: bool) -> None: X = GenericDataLoader(Xraw) for _ in range(2): - test_plugin = plugin(**plugin_args, compress_dataset=compress_dataset) + test_plugin = plugin(**plugin_params, compress_dataset=compress_dataset) evaluator = PerformanceEvaluatorXGB() test_plugin.fit(X) From 52be80f2ecb3aab0c9a3a5cd782744814127872d Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 16 Mar 2023 15:34:47 +0100 Subject: [PATCH 26/95] update EarlyStopping --- src/synthcity/utils/callbacks.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/synthcity/utils/callbacks.py b/src/synthcity/utils/callbacks.py index fa54074c..0e6b15cf 100644 --- 
a/src/synthcity/utils/callbacks.py +++ b/src/synthcity/utils/callbacks.py @@ -1,6 +1,5 @@ # stdlib from abc import ABC, abstractmethod -from typing import Optional # third party import numpy as np @@ -34,25 +33,29 @@ def on_fit_end(self, model: nn.Module) -> None: class EarlyStopping(Callback): def __init__( self, + patience_metric: WeightedMetrics, patience: int = 5, min_epochs: int = 100, - patience_metric: Optional[WeightedMetrics] = None, ) -> None: self.patience = patience - self.patience_metric = patience_metric self.min_epochs = min_epochs + self.patience_metric = patience_metric self.best_score = self._init_patience_score() self.best_model_state = None self.wait = 0 self._epochs = 0 + def on_fit_begin(self, model: nn.Module) -> None: + self.X_val = model.X_val + if isinstance(self.X_val, Tensor): + self.X_val = self.X_val.detach().cpu().numpy() + + def on_epoch_begin(self, model: nn.Module) -> None: + pass + def on_epoch_end(self, model: nn.Module) -> None: self._epochs += 1 if self.patience_metric is not None: - if not hasattr(self, "X_val"): - self.X_val = model.X_val - if isinstance(self.X_val, Tensor): - self.X_val = self.X_val.detach().cpu().numpy() self._evaluate_patience_metric(model) if self.wait >= self.patience and self._epochs >= self.min_epochs: raise StopIteration("Early stopping") @@ -62,9 +65,7 @@ def on_fit_end(self, model: nn.Module) -> None: model.load_state_dict(self.best_model_state) # type: ignore def _init_patience_score(self) -> float: - if self.patience_metric is None: - return 0 - elif self.patience_metric.direction() == "minimize": + if self.patience_metric.direction() == "minimize": return np.inf else: return -np.inf @@ -73,12 +74,12 @@ def _evaluate_patience_metric(self, model: nn.Module) -> None: X_val = self.X_val X_syn = model.generate(len(X_val)) - new_score = self.patience_metric.evaluate( # type: ignore + new_score = self.patience_metric.evaluate( pd.DataFrame(X_val), pd.DataFrame(X_syn), ) - if 
self.patience_metric.direction() == "minimize": # type: ignore + if self.patience_metric.direction() == "minimize": is_new_best = new_score < self.best_score else: is_new_best = new_score > self.best_score From 794ebd61be9fc966871a419604d0813ed3a35bee Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 27 Mar 2023 20:52:17 +0200 Subject: [PATCH 27/95] add TabDDPM tutorial, update TabDDPM plugin and encoders --- .../plugins/core/models/data_encoder.py | 150 +- .../core/models/tabular_ddpm/__init__.py | 22 +- .../gaussian_multinomial_diffsuion.py | 4 +- .../core/models/tabular_ddpm/modules.py | 10 +- .../core/models/tabular_ddpm/nn_utils.py | 169 ++ .../plugins/core/models/tabular_ddpm/utils.py | 541 +++-- .../plugins/core/models/tabular_encoder.py | 242 +-- src/synthcity/plugins/generic/plugin_ddpm.py | 18 +- src/synthcity/utils/dataframe.py | 12 + ...al8_tabular_modelling_with_diffusion.ipynb | 1936 +++++++++++++++++ 10 files changed, 2763 insertions(+), 341 deletions(-) create mode 100644 src/synthcity/plugins/core/models/tabular_ddpm/nn_utils.py create mode 100644 tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb diff --git a/src/synthcity/plugins/core/models/data_encoder.py b/src/synthcity/plugins/core/models/data_encoder.py index 75915afb..9f432d9c 100644 --- a/src/synthcity/plugins/core/models/data_encoder.py +++ b/src/synthcity/plugins/core/models/data_encoder.py @@ -1,40 +1,83 @@ # stdlib -from typing import Any, List, Optional +from functools import wraps +from typing import Any, List, Optional, Union # third party import numpy as np import pandas as pd from pydantic import validate_arguments +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.mixture import BayesianGaussianMixture +from sklearn.preprocessing import ( + MinMaxScaler, + OneHotEncoder, + QuantileTransformer, + StandardScaler, +) -class DatetimeEncoder: - """Datetime encoder, with sklearn-style API""" - - def __init__(self) -> None: - pass +class 
_DataEncoder(TransformerMixin, BaseEstimator): + """Base data encoder, with sklearn-style API""" @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit(self, X: pd.Series) -> Any: + def fit(self, X: Any) -> Any: + return self._fit(X) + + def _fit(self, X: Any) -> Any: return self @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def transform(self, X: pd.Series) -> pd.Series: - out = pd.to_numeric(X).astype(float) - return out + def transform(self, X: Any) -> Any: + return self._transform(X) + + def _transform(self, X: Any) -> Any: + return X @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def inverse_transform(self, X: pd.Series) -> pd.Series: - out = pd.to_datetime(X) - return out + def inverse_transform(self, X: Any) -> Any: + return self._inverse_transform(X) + + def _inverse_transform(self, X: Any) -> Any: + return X @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit_transform(self, X: pd.Series) -> pd.Series: + def fit_transform(self, X: Any) -> Any: return self.fit(X).transform(X) + @classmethod + def wraps(cls, encoder_class: TransformerMixin) -> type: + """Wraps sklearn encoder to DataEncoder.""" + + @wraps(encoder_class) + class WrappedEncoder(_DataEncoder): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.encoder = encoder_class(*args, **kwargs) + + def _fit(self, X: Any) -> _DataEncoder: + self.encoder.fit(X) + return self + + def _transform(self, X: Any) -> Any: + return self.encoder.transform(X) + + def _inverse_transform(self, X: Any) -> Any: + return self.encoder.inverse_transform(X) + + return WrappedEncoder + + +class DatetimeEncoder(_DataEncoder): + """Datetime variables encoder""" + + def _transform(self, X: pd.Series) -> pd.Series: + return pd.to_numeric(X).astype(float) + + def _inverse_transform(self, X: pd.Series) -> pd.Series: + return pd.to_datetime(X) + -class ContinuousDataEncoder: - """Continuous variables encoder""" +class 
BayesianGMMEncoder(_DataEncoder): + """Bayesian Gaussian Mixture encoder""" def __init__( self, @@ -52,18 +95,17 @@ def __init__( self.weights: Optional[List[float]] = None self.std_multiplier = 4 - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit(self, X: pd.Series) -> Any: + def _fit(self, X: pd.DataFrame) -> Any: self.min_value = X.min() self.max_value = X.max() self.model.fit(X.values.reshape(-1, 1)) self.weights = self.model.weights_ + self.n_components = len(self.model.weights_) return self - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def transform(self, X: pd.Series) -> pd.DataFrame: + def _transform(self, X: pd.DataFrame) -> pd.DataFrame: name = X.name X = X.values.reshape(-1, 1) means = self.model.means_.reshape(1, self.n_components) @@ -85,8 +127,7 @@ def transform(self, X: pd.Series) -> pd.DataFrame: return pd.DataFrame(out, columns=[f"{name}.value", f"{name}.component"]) - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def inverse_transform(self, X: pd.DataFrame) -> pd.Series: + def _inverse_transform(self, X: pd.DataFrame) -> pd.DataFrame: normalized = np.clip(X.values[:, 0], -1, 1) means = self.model.means_.reshape([-1]) stds = np.sqrt(self.model.covariances_).reshape([-1]) @@ -100,11 +141,62 @@ def inverse_transform(self, X: pd.DataFrame) -> pd.Series: # clip values return np.clip(reversed_data, self.min_value, self.max_value) - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit_transform(self, X: pd.Series) -> pd.Series: - return self.fit(X).transform(X) - def components(self) -> int: - if self.weights is None: - raise RuntimeError("Train the model first") - return len(self.weights) +OneHotEncoder = _DataEncoder.wraps(OneHotEncoder) +StandardScaler = _DataEncoder.wraps(StandardScaler) +MinMaxScaler = _DataEncoder.wraps(MinMaxScaler) + + +@_DataEncoder.wraps +class GaussianQuantileTransformer(QuantileTransformer): + """Quantile transformer with Gaussian 
distribution""" + + def __init__( + self, + *, + ignore_implicit_zeros: bool = False, + subsample: int = 10000, + random_state: Any = None, + copy: bool = True, + ): + super().__init__( + n_quantiles=None, + output_distribution="normal", + ignore_implicit_zeros=ignore_implicit_zeros, + subsample=subsample, + random_state=random_state, + copy=copy, + ) + + def fit(self, X: pd.DataFrame, y: Any = None) -> "GaussianQuantileTransformer": + self.n_quantiles = max(min(len(X) // 30, 1000), 10) + return super().fit(X, y) + + +REGISTRY = { + "datetime": DatetimeEncoder, + "onehot": OneHotEncoder, + "standard": StandardScaler, + "minmax": MinMaxScaler, + "quantile": GaussianQuantileTransformer, + "bayesian_gmm": BayesianGMMEncoder, +} + + +def get_encoder(encoder: Union[str, type]) -> TransformerMixin: + """Get a registered encoder. + + Supported encoders: + - Datetime + - datetime + - Categorical + - onehot + - Continuous + - standard + - minmax + - quantile + - bayesian_gmm + """ + if isinstance(encoder, type): # custom encoder + return encoder + return REGISTRY[encoder] diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index d80c2a85..d6141f81 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -99,7 +99,7 @@ def fit( model_params = dict( num_classes=self.n_classes, - use_label=cond is not None, + conditional=cond is not None, mlp_params=self.mlp_params, dim_emb=self.dim_embed, ) @@ -139,7 +139,7 @@ def fit( for cbk in self.callbacks: cbk.on_fit_begin(self) - self.loss_history = pd.DataFrame(columns=["step", "mloss", "gloss", "loss"]) + self.loss_history = [] steps = 0 curr_loss_multi = 0.0 @@ -174,12 +174,14 @@ def fit( info( f"Step {steps}: MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}" ) - self.loss_history.loc[len(self.loss_history)] = [ - steps, - mloss, - gloss, - mloss + gloss, - ] + 
self.loss_history.append( + [ + steps, + mloss, + gloss, + mloss + gloss, + ] + ) curr_count = 0 curr_loss_gauss = 0.0 curr_loss_multi = 0.0 @@ -196,6 +198,10 @@ def fit( info(f"Early stopped at epoch {epoch}") break + self.loss_history = pd.DataFrame( + self.loss_history, columns=["step", "mloss", "gloss", "loss"] + ).set_index("step") + for cbk in self.callbacks: cbk.on_fit_end(self) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 270d8b03..db55aedb 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -19,7 +19,7 @@ # synthcity relative from .modules import MLPDiffusion, ResNetDiffusion -from .utils import ( +from .nn_utils import ( discretized_gaussian_log_likelihood, index_to_log_onehot, log_1_min_a, @@ -112,7 +112,7 @@ def __init__( if model_params is None: model_params = dict( - dim_in=self.dim_input, num_classes=0, use_label=False, mlp_params=None + dim_in=self.dim_input, num_classes=0, conditional=False, mlp_params=None ) else: model_params["dim_in"] = self.dim_input diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 297c01bf..8d3a777b 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -67,7 +67,7 @@ def __init__( dim_emb: int = 128, *, mlp_params: dict = {}, - use_label: bool = False, + conditional: bool = False, num_classes: int = 0, emb_nonlin: Union[str, nn.Module] = "silu", max_time_period: int = 10000, @@ -75,7 +75,7 @@ def __init__( super().__init__() self.dim_t = dim_emb self.num_classes = num_classes - self.has_label = use_label + self.has_label = conditional if isinstance(emb_nonlin, str): self.emb_nonlin = 
get_nonlin(emb_nonlin) @@ -85,7 +85,7 @@ def __init__( self.proj = nn.Linear(dim_in, dim_emb) self.time_emb = TimeStepEmbedding(dim_emb, max_time_period) - if use_label: + if conditional: if self.num_classes > 0: self.label_emb = nn.Embedding(self.num_classes, dim_emb) elif self.num_classes == 0: # regression @@ -103,9 +103,9 @@ def forward(self, x: Tensor, t: Tensor, y: Optional[Tensor] = None) -> Tensor: emb = self.time_emb(t) if self.has_label: if y is None: - raise ValueError("y must be provided if use_label is True") + raise ValueError("y must be provided if conditional is True") if self.num_classes == 0: - y = y.resize(-1, 1).float() + y = y.reshape(-1, 1).float() else: y = y.squeeze().long() emb += self.emb_nonlin(self.label_emb(y)) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/nn_utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/nn_utils.py new file mode 100644 index 00000000..04eb9d8f --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_ddpm/nn_utils.py @@ -0,0 +1,169 @@ +# future +from __future__ import annotations + +# third party +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +def normal_kl(mean1: Tensor, logvar1: Tensor, mean2: Tensor, logvar2: Tensor) -> Tensor: + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + try: + tensor = next( + x for x in (mean1, logvar1, mean2, logvar2) if isinstance(x, Tensor) + ) + except StopIteration: + raise TypeError("at least one argument must be a Tensor") + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x: Tensor) -> Tensor: + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * ( + 1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))) + ) + + +def discretized_gaussian_log_likelihood( + x: Tensor, *, means: Tensor, log_scales: Tensor +) -> Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + if not (x.shape == means.shape == log_scales.shape): + raise ValueError("shapes must match") + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where( + x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)) + ), + ) + if log_probs.shape != x.shape: + raise ValueError("shapes must match") + return log_probs + + +def sum_except_batch(x: Tensor, num_dims: int = 1) -> Tensor: + """ + Sums all dimensions except the first. + + Args: + x: Tensor, shape (batch_size, ...) 
+ num_dims: int, number of batch dims (default=1) + + Returns: + x_sum: Tensor, shape (batch_size,) + """ + return x.reshape(*x.shape[:num_dims], -1).sum(-1) + + +def mean_flat(tensor: Tensor) -> Tensor: + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def ohe_to_categories(ohe: Tensor, K: np.ndarray) -> Tensor: + K = torch.from_numpy(K) + indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist() + res = [] + for i in range(len(indices) - 1): + res.append(ohe[:, indices[i] : indices[i + 1]].argmax(dim=1)) + return torch.stack(res, dim=1) + + +def log_1_min_a(a: Tensor) -> Tensor: + return torch.log(1 - a.exp() + 1e-40) + + +def log_add_exp(a: Tensor, b: Tensor) -> Tensor: + """Numerically stable log(exp(a) + exp(b)).""" + m = torch.max(a, b) + return m + torch.log(torch.exp(a - m) + torch.exp(b - m)) + + +def perm_and_expand(a: Tensor, t: Tensor, x_shape: tuple) -> Tensor: + """Permutes a tensor in the order specified by `t` and expands it to `x_shape`.""" + if not (a.ndim == 1 and t.shape == (x_shape[0],)): + raise ValueError(f"dimensionality mismatch: {a.shape}, {t.shape}, {x_shape}") + out = a[t] + while len(out.shape) < len(x_shape): + out = out[..., None] + return out.expand(x_shape) + + +def log_categorical(log_x_start: Tensor, log_prob: Tensor) -> Tensor: + return (log_x_start.exp() * log_prob).sum(dim=1) + + +def index_to_log_onehot(x: Tensor, num_classes: np.ndarray) -> Tensor: + onehots = [] + for i in range(len(num_classes)): + onehots.append(F.one_hot(x[:, i], num_classes[i])) + x_onehot = torch.cat(onehots, dim=1) + log_onehot = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_onehot + + +@torch.jit.script +def log_sub_exp(a: Tensor, b: Tensor) -> Tensor: + m = torch.maximum(a, b) + return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m + + +@torch.jit.script +def sliced_logsumexp(x: Tensor, slices: Tensor) -> Tensor: + lse = torch.logcumsumexp( 
+ torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float("inf")), dim=-1 + ) + + slice_starts = slices[:-1] + slice_ends = slices[1:] + + slice_lse = log_sub_exp(lse[:, slice_ends], lse[:, slice_starts]) + slice_lse_repeated = torch.repeat_interleave( + slice_lse, slice_ends - slice_starts, dim=-1 + ) + return slice_lse_repeated diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py index 04eb9d8f..8574ffec 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/utils.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/utils.py @@ -1,169 +1,408 @@ -# future -from __future__ import annotations +# mypy: ignore-errors + +# stdlib +from collections import Counter +from copy import deepcopy +from dataclasses import dataclass, replace +from typing import Any, Dict, Literal, Optional, Tuple, Union, cast # third party import numpy as np +import pandas as pd +import sklearn.preprocessing import torch -import torch.nn.functional as F -from torch import Tensor - - -def normal_kl(mean1: Tensor, logvar1: Tensor, mean2: Tensor, logvar2: Tensor) -> Tensor: - """ - Compute the KL divergence between two gaussians. - - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. 
- """ - try: - tensor = next( - x for x in (mean1, logvar1, mean2, logvar2) if isinstance(x, Tensor) +from sklearn.impute import SimpleImputer + +# synthcity absolute +from synthcity.utils.dataframe import TaskType + +ArrayDict = Dict[str, np.ndarray] +TensorDict = Dict[str, torch.Tensor] + + +CAT_MISSING_VALUE = "__nan__" +CAT_RARE_VALUE = "__rare__" +Normalization = Literal["standard", "quantile", "minmax"] +NumNanPolicy = Literal["drop-rows", "mean"] +CatNanPolicy = Literal["most_frequent"] + + +@dataclass(frozen=False) +class Dataset: + X_num: Optional[ArrayDict] + X_cat: Optional[ArrayDict] + y: ArrayDict + task_type: TaskType + n_classes: Optional[int] + + @property + def is_binclass(self) -> bool: + return self.task_type == TaskType.BINARY + + @property + def is_multiclass(self) -> bool: + return self.task_type == TaskType.MULTICLASS + + @property + def is_regression(self) -> bool: + return self.task_type == TaskType.REGRESSION + + @property + def n_num_features(self) -> int: + return 0 if self.X_num is None else self.X_num["train"].shape[1] + + @property + def n_cat_features(self) -> int: + return 0 if self.X_cat is None else self.X_cat["train"].shape[1] + + @property + def n_features(self) -> int: + return self.n_num_features + self.n_cat_features + + def size(self, part: Optional[str]) -> int: + return sum(map(len, self.y.values())) if part is None else len(self.y[part]) + + @property + def nn_output_dim(self) -> int: + if self.is_multiclass: + assert self.n_classes is not None + return self.n_classes + else: + return 1 + + +def num_process_nans(dataset: Dataset, policy: Optional[NumNanPolicy]) -> Dataset: + assert dataset.X_num is not None + nan_masks = {k: np.isnan(v) for k, v in dataset.X_num.items()} + if not any(x.any() for x in nan_masks.values()): + assert policy is None + return dataset + + assert policy is not None + if policy == "drop-rows": + valid_masks = {k: ~v.any(1) for k, v in nan_masks.items()} + assert valid_masks[ + "test" + ].all(), 
"Cannot drop test rows, since this will affect the final metrics." + new_data = {} + for data_name in ["X_num", "X_cat", "y"]: + data_dict = getattr(dataset, data_name) + if data_dict is not None: + new_data[data_name] = { + k: v[valid_masks[k]] for k, v in data_dict.items() + } + dataset = replace(dataset, **new_data) + elif policy == "mean": + new_values = np.nanmean(dataset.X_num["train"], axis=0) + X_num = deepcopy(dataset.X_num) + for k, v in X_num.items(): + num_nan_indices = np.where(nan_masks[k]) + v[num_nan_indices] = np.take(new_values, num_nan_indices[1]) + dataset = replace(dataset, X_num=X_num) + else: + assert raise_unknown("policy", policy) + return dataset + + +# Inspired by: https://github.com/yandex-research/rtdl/blob/a4c93a32b334ef55d2a0559a4407c8306ffeeaee/lib/data.py#L20 +def normalize( + X: ArrayDict, + normalization: Normalization, + seed: Optional[int], + return_normalizer: bool = False, +) -> ArrayDict: + X_train = X["train"] + if normalization == "standard": + normalizer = sklearn.preprocessing.StandardScaler() + elif normalization == "minmax": + normalizer = sklearn.preprocessing.MinMaxScaler() + elif normalization == "quantile": + normalizer = sklearn.preprocessing.QuantileTransformer( + output_distribution="normal", + n_quantiles=max(min(X["train"].shape[0] // 30, 1000), 10), + subsample=1e9, + random_state=seed, ) - except StopIteration: - raise TypeError("at least one argument must be a Tensor") - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). 
- logvar1, logvar2 = [ - x if isinstance(x, Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + else: + raise_unknown("normalization", normalization) + normalizer.fit(X_train) + if return_normalizer: + return {k: normalizer.transform(v) for k, v in X.items()}, normalizer + return {k: normalizer.transform(v) for k, v in X.items()} + + +def cat_process_nans(X: ArrayDict, policy: Optional[CatNanPolicy]) -> ArrayDict: + assert X is not None + nan_masks = {k: v == CAT_MISSING_VALUE for k, v in X.items()} + if any(x.any() for x in nan_masks.values()): + if policy is None: + X_new = X + elif policy == "most_frequent": + imputer = SimpleImputer(missing_values=CAT_MISSING_VALUE, strategy=policy) + imputer.fit(X["train"]) + X_new = {k: cast(np.ndarray, imputer.transform(v)) for k, v in X.items()} + else: + raise_unknown("categorical NaN policy", policy) + else: + assert policy is None + X_new = X + return X_new + + +def cat_drop_rare(X: ArrayDict, min_frequency: float) -> ArrayDict: + assert 0.0 < min_frequency < 1.0 + min_count = round(len(X["train"]) * min_frequency) + X_new = {x: [] for x in X} + for column_idx in range(X["train"].shape[1]): + counter = Counter(X["train"][:, column_idx].tolist()) + popular_categories = {k for k, v in counter.items() if v >= min_count} + for part in X_new: + X_new[part].append( + [ + (x if x in popular_categories else CAT_RARE_VALUE) + for x in X[part][:, column_idx].tolist() + ] + ) + return {k: np.array(v).T for k, v in X_new.items()} + + +def build_target(y: ArrayDict, task_type: TaskType) -> Tuple[ArrayDict, Dict[str, Any]]: + info: Dict[str, Any] = {} + if task_type == TaskType.REGRESSION: + mean, std = float(y["train"].mean()), float(y["train"].std()) + y = {k: (v - mean) / std for k, v in y.items()} + info["mean"] = mean + info["std"] = std + return y, info + + 
+@dataclass(frozen=True) +class Transformations: + seed: int = 0 + normalization: Optional[Normalization] = None + num_nan_policy: Optional[NumNanPolicy] = None + cat_nan_policy: Optional[CatNanPolicy] = None + cat_min_frequency: Optional[float] = None + + +def transform_dataset( + dataset: Dataset, + transformations: Transformations, +) -> Dataset: + # WARNING: the order of transformations matters. Moreover, the current + # implementation is not ideal in that sense. + + if dataset.X_num is not None: + dataset = num_process_nans(dataset, transformations.num_nan_policy) + + num_transform = None + cat_transform = None + X_num = dataset.X_num + + if X_num is not None and transformations.normalization is not None: + X_num, num_transform = normalize( + X_num, + transformations.normalization, + transformations.seed, + return_normalizer=True, + ) + num_transform = num_transform + + if dataset.X_cat is None: + assert transformations.cat_nan_policy is None + assert transformations.cat_min_frequency is None + # assert transformations.cat_encoding is None + X_cat = None + else: + X_cat = cat_process_nans(dataset.X_cat, transformations.cat_nan_policy) + if transformations.cat_min_frequency is not None: + X_cat = cat_drop_rare(X_cat, transformations.cat_min_frequency) + + y, y_info = build_target(dataset.y, dataset.task_type) + + dataset = replace(dataset, X_num=X_num, X_cat=X_cat, y=y) + dataset.num_transform = num_transform + dataset.cat_transform = cat_transform + + return dataset + + +def make_dataset( + df: pd.DataFrame, + target: str, + cat_counts: Dict[str, int], + T: Transformations, +) -> Dataset: + # classification + if len(cat_counts) > 0: + task_type = TaskType.CLASSIFICATION + else: + task_type = TaskType.REGRESSION + + X_cat = df[list(cat_counts.keys())] + X_num = df.drop(columns=list(X_cat.keys()) + [target]) + y = df[target] + + D = Dataset( + X_num, + X_cat, + y, + task_type=TaskType(task_type), ) - -def approx_standard_normal_cdf(x: Tensor) -> Tensor: - """ - 
A fast approximation of the cumulative distribution function of the - standard normal. - """ - return 0.5 * ( - 1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))) - ) + return transform_dataset(D, T, None) -def discretized_gaussian_log_likelihood( - x: Tensor, *, means: Tensor, log_scales: Tensor -) -> Tensor: - """ - Compute the log-likelihood of a Gaussian distribution discretizing to a - given image. - - :param x: the target images. It is assumed that this was uint8 values, - rescaled to the range [-1, 1]. - :param means: the Gaussian mean Tensor. - :param log_scales: the Gaussian log stddev Tensor. - :return: a tensor like x of log probabilities (in nats). - """ - if not (x.shape == means.shape == log_scales.shape): - raise ValueError("shapes must match") - centered_x = x - means - inv_stdv = torch.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 1.0 / 255.0) - cdf_plus = approx_standard_normal_cdf(plus_in) - min_in = inv_stdv * (centered_x - 1.0 / 255.0) - cdf_min = approx_standard_normal_cdf(min_in) - log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) - cdf_delta = cdf_plus - cdf_min - log_probs = torch.where( - x < -0.999, - log_cdf_plus, - torch.where( - x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)) - ), +def prepare_tensors( + dataset: Dataset, device: Union[str, torch.device] +) -> Tuple[Optional[TensorDict], Optional[TensorDict], TensorDict]: + X_num, X_cat, Y = ( + None if x is None else {k: torch.as_tensor(v) for k, v in x.items()} + for x in [dataset.X_num, dataset.X_cat, dataset.y] ) - if log_probs.shape != x.shape: - raise ValueError("shapes must match") - return log_probs - - -def sum_except_batch(x: Tensor, num_dims: int = 1) -> Tensor: - """ - Sums all dimensions except the first. - - Args: - x: Tensor, shape (batch_size, ...) 
- num_dims: int, number of batch dims (default=1) - - Returns: - x_sum: Tensor, shape (batch_size,) - """ - return x.reshape(*x.shape[:num_dims], -1).sum(-1) - - -def mean_flat(tensor: Tensor) -> Tensor: - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def ohe_to_categories(ohe: Tensor, K: np.ndarray) -> Tensor: - K = torch.from_numpy(K) - indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist() - res = [] - for i in range(len(indices) - 1): - res.append(ohe[:, indices[i] : indices[i + 1]].argmax(dim=1)) - return torch.stack(res, dim=1) - - -def log_1_min_a(a: Tensor) -> Tensor: - return torch.log(1 - a.exp() + 1e-40) - - -def log_add_exp(a: Tensor, b: Tensor) -> Tensor: - """Numerically stable log(exp(a) + exp(b)).""" - m = torch.max(a, b) - return m + torch.log(torch.exp(a - m) + torch.exp(b - m)) - - -def perm_and_expand(a: Tensor, t: Tensor, x_shape: tuple) -> Tensor: - """Permutes a tensor in the order specified by `t` and expands it to `x_shape`.""" - if not (a.ndim == 1 and t.shape == (x_shape[0],)): - raise ValueError(f"dimensionality mismatch: {a.shape}, {t.shape}, {x_shape}") - out = a[t] - while len(out.shape) < len(x_shape): - out = out[..., None] - return out.expand(x_shape) - - -def log_categorical(log_x_start: Tensor, log_prob: Tensor) -> Tensor: - return (log_x_start.exp() * log_prob).sum(dim=1) - + if device.type != "cpu": + X_num, X_cat, Y = ( + None if x is None else {k: v.to(device) for k, v in x.items()} + for x in [X_num, X_cat, Y] + ) + assert X_num is not None + assert Y is not None + if not dataset.is_multiclass: + Y = {k: v.float() for k, v in Y.items()} + return X_num, X_cat, Y -def index_to_log_onehot(x: Tensor, num_classes: np.ndarray) -> Tensor: - onehots = [] - for i in range(len(num_classes)): - onehots.append(F.one_hot(x[:, i], num_classes[i])) - x_onehot = torch.cat(onehots, dim=1) - log_onehot = 
torch.log(x_onehot.float().clamp(min=1e-30)) - return log_onehot +############## +# DataLoader # +############## -@torch.jit.script -def log_sub_exp(a: Tensor, b: Tensor) -> Tensor: - m = torch.maximum(a, b) - return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m +class TabDataset(torch.utils.data.Dataset): + def __init__(self, dataset: Dataset, split: Literal["train", "val", "test"]): + super().__init__() -@torch.jit.script -def sliced_logsumexp(x: Tensor, slices: Tensor) -> Tensor: - lse = torch.logcumsumexp( - torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float("inf")), dim=-1 + self.X_num = ( + torch.from_numpy(dataset.X_num[split]) + if dataset.X_num is not None + else None + ) + self.X_cat = ( + torch.from_numpy(dataset.X_cat[split]) + if dataset.X_cat is not None + else None + ) + self.y = torch.from_numpy(dataset.y[split]) + + assert self.y is not None + assert self.X_num is not None or self.X_cat is not None + + def __len__(self) -> int: + return len(self.y) + + def __getitem__(self, idx): + out_dict = { + "y": self.y[idx].long() if self.y is not None else None, + } + + x = np.empty((0,)) + if self.X_num is not None: + x = self.X_num[idx] + if self.X_cat is not None: + x = torch.cat([x, self.X_cat[idx]], dim=0) + return x.float(), out_dict + + +# def prepare_dataloader( +# dataset: Dataset, +# split: str, +# batch_size: int, +# ): +# torch_dataset = TabDataset(dataset, split) +# loader = torch.utils.data.DataLoader( +# torch_dataset, +# batch_size=batch_size, +# shuffle=(split == "train"), +# num_workers=1, +# ) +# while True: +# yield from loader + + +# def prepare_torch_dataloader( +# dataset: Dataset, +# split: str, +# shuffle: bool, +# batch_size: int, +# ) -> torch.utils.data.DataLoader: + +# torch_dataset = TabDataset(dataset, split) +# loader = torch.utils.data.DataLoader( +# torch_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1 +# ) + +# return loader + + +def concat_features(D: Dataset): + if D.X_num is None: + assert D.X_cat 
is not None + X = { + k: pd.DataFrame(v, columns=range(D.n_features)) for k, v in D.X_cat.items() + } + elif D.X_cat is None: + assert D.X_num is not None + X = { + k: pd.DataFrame(v, columns=range(D.n_features)) for k, v in D.X_num.items() + } + else: + X = { + part: pd.concat( + [ + pd.DataFrame(D.X_num[part], columns=range(D.n_num_features)), + pd.DataFrame( + D.X_cat[part], + columns=range(D.n_num_features, D.n_features), + ), + ], + axis=1, + ) + for part in D.y.keys() + } + + return X + + +def concat_to_pd(X_num, X_cat, y): + if X_num is None: + return pd.concat( + [ + pd.DataFrame(X_cat, columns=list(range(X_cat.shape[1]))), + pd.DataFrame(y, columns=["y"]), + ], + axis=1, + ) + if X_cat is not None: + return pd.concat( + [ + pd.DataFrame(X_num, columns=list(range(X_num.shape[1]))), + pd.DataFrame( + X_cat, + columns=list( + range(X_num.shape[1], X_num.shape[1] + X_cat.shape[1]) + ), + ), + pd.DataFrame(y, columns=["y"]), + ], + axis=1, + ) + return pd.concat( + [ + pd.DataFrame(X_num, columns=list(range(X_num.shape[1]))), + pd.DataFrame(y, columns=["y"]), + ], + axis=1, ) - slice_starts = slices[:-1] - slice_ends = slices[1:] - slice_lse = log_sub_exp(lse[:, slice_ends], lse[:, slice_starts]) - slice_lse_repeated = torch.repeat_interleave( - slice_lse, slice_ends - slice_starts, dim=-1 - ) - return slice_lse_repeated +def raise_unknown(unknown_what: str, unknown_value: Any): + raise ValueError(f"Unknown {unknown_what}: {unknown_value}") diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 360fcf56..a9bb0e82 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -9,7 +9,6 @@ import pandas as pd from pydantic import BaseModel, validate_arguments, validator from sklearn.base import BaseEstimator, TransformerMixin -from sklearn.preprocessing import MinMaxScaler, OneHotEncoder # synthcity absolute import 
synthcity.logger as log @@ -17,7 +16,7 @@ from synthcity.utils.serialization import dataframe_hash # synthcity relative -from .data_encoder import ContinuousDataEncoder +from .data_encoder import get_encoder class FeatureInfo(BaseModel): @@ -50,107 +49,6 @@ def _output_dimensions_validator(cls: Any, v: int) -> int: return v -class BinEncoder(TransformerMixin, BaseEstimator): - """Binary encoder (for SurvivalGAN). - - Model continuous columns with a BayesianGMM and normalized to a scalar [0, 1] and a vector. - Discrete columns are encoded using a scikit-learn OneHotEncoder. - """ - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def __init__( - self, - max_clusters: int = 10, - categorical_limit: int = 10, - ) -> None: - """Create a data transformer. - - Args: - max_clusters (int): - Maximum number of Gaussian distributions in Bayesian GMM. - """ - self.max_clusters = max_clusters - self.categorical_limit = categorical_limit - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def _fit_continuous(self, data: pd.Series) -> FeatureInfo: - """Train Bayesian GMM for continuous columns. - - Args: - data (pd.Series): - A dataframe containing a column. - - Returns: - namedtuple: - A ``FeatureInfo`` object. - """ - name = data.name - encoder = ContinuousDataEncoder( - n_components=min(self.max_clusters, len(data)), - ) - encoder.fit(data) - num_components = encoder.components() - - transformed_features = [f"{name}.value"] + [ - f"{name}.component_{i}" for i in range(num_components) - ] - - return FeatureInfo( - name=name, - feature_type="continuous", - transform=encoder, - output_dimensions=1 + num_components, - transformed_features=transformed_features, - ) - - def fit( - self, raw_data: pd.DataFrame, discrete_columns: Optional[List] = None - ) -> "BinEncoder": - """Fit the ``BinEncoder``. 
- - Fits a ``ContinuousDataEncoder`` for continuous columns - """ - if discrete_columns is None: - discrete_columns = find_cat_cols(raw_data, self.categorical_limit) - - self.output_dimensions = 0 - - self._column_transform_info = {} - for name in raw_data.columns: - if name not in discrete_columns: - column_transform_info = self._fit_continuous(raw_data[name]) - self._column_transform_info[name] = column_transform_info - - return self - - def _transform_continuous( - self, column_transform_info: FeatureInfo, data: pd.Series - ) -> pd.Series: - name = data.name - encoder = column_transform_info.transform - transformed = encoder.transform(data) - - return transformed[f"{name}.component"].to_numpy().astype(int) - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: - """Take raw data and output a matrix data.""" - output = raw_data.copy() - - for name in self._column_transform_info: - column_transform_info = self._column_transform_info[name] - - output[name] = self._transform_continuous( - column_transform_info, raw_data[name] - ) - - return output - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit_transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: - return self.fit(raw_data).transform(raw_data) - - class TabularEncoder(TransformerMixin, BaseEstimator): """Tabular encoder. @@ -164,6 +62,8 @@ def __init__( max_clusters: int = 10, categorical_limit: int = 10, whitelist: list = [], + categorical_encoder: str = "onehot", + continuous_encoder: str = "bayesian_gmm", ) -> None: """Create a data transformer. 
@@ -174,10 +74,12 @@ def __init__( self.max_clusters = max_clusters self.categorical_limit = categorical_limit self.whitelist = whitelist + self.categorical_encoder = categorical_encoder + self.continuous_encoder = continuous_encoder @validate_arguments(config=dict(arbitrary_types_allowed=True)) def _fit_continuous(self, data: pd.Series) -> FeatureInfo: - """Train Bayesian GMM for continuous columns. + """Fit the continuous encoder on a continuous column. Args: data (pd.DataFrame): @@ -188,20 +90,28 @@ def _fit_continuous(self, data: pd.Series) -> FeatureInfo: A ``FeatureInfo`` object. """ name = data.name - encoder = ContinuousDataEncoder( - n_components=min(len(data), self.max_clusters), - ) + + if self.continuous_encoder == "bayesian_gmm": + encoder = get_encoder("bayesian_gmm")( + n_components=min(self.max_clusters, len(data)), + ) + n_components = encoder.n_components + dim_out = 1 + n_components + transformed_features = [f"{name}.value"] + [ + f"{name}.component_{i}" for i in range(n_components) + ] + else: + encoder = get_encoder(self.continuous_encoder)() + dim_out = 1 + transformed_features = [name] + encoder.fit(data) - num_components = encoder.components() - transformed_features = [f"{name}.value"] + [ - f"{name}.component_{i}" for i in range(num_components) - ] return FeatureInfo( name=name, feature_type="continuous", transform=encoder, - output_dimensions=1 + num_components, + output_dimensions=dim_out, transformed_features=transformed_features, ) @@ -218,16 +128,21 @@ def _fit_discrete(self, data: pd.Series) -> FeatureInfo: A ``FeatureInfo`` object. 
""" name = data.name - ohe = OneHotEncoder(handle_unknown="ignore", sparse=False) - ohe.fit(data.values.reshape(-1, 1)) - num_categories = len(ohe.categories_[0]) - transformed_features = list(ohe.get_feature_names_out([data.name])) + if self.categorical_encoder == "onehot": + encoder = get_encoder("onehot")(handle_unknown="ignore", sparse=False) + else: + raise ValueError(f"Unknown categorical encoder {self.categorical_encoder}") + + encoder.fit(data.values.reshape(-1, 1)) + num_categories = len(encoder.categories_[0]) + + transformed_features = list(encoder.get_feature_names_out([data.name])) return FeatureInfo( name=name, feature_type="discrete", - transform=ohe, + transform=encoder, output_dimensions=num_categories, transformed_features=transformed_features, ) @@ -238,17 +153,15 @@ def fit( ) -> Any: """Fit the ``TabularEncoder``. - Fits a ``ContinuousDataEncoder`` for continuous columns and a - ``OneHotEncoder`` for discrete columns. - This step also counts the #columns in matrix data and span information. """ if discrete_columns is None: discrete_columns = find_cat_cols(raw_data, self.categorical_limit) + self.output_dimensions = 0 self._column_raw_dtypes = raw_data.infer_objects().dtypes - self._column_transform_info_list = [] + self._column_transform_info = [] for name in raw_data.columns: if name in self.whitelist: @@ -262,7 +175,8 @@ def fit( column_transform_info = self._fit_continuous(raw_data[name]) self.output_dimensions += column_transform_info.output_dimensions - self._column_transform_info_list.append(column_transform_info) + self._column_transform_info.append(column_transform_info) + return self def _transform_continuous( @@ -273,10 +187,15 @@ def _transform_continuous( transformed = encoder.transform(data) # Converts the transformed data to the appropriate output format. 
- output = np.zeros((len(transformed), column_transform_info.output_dimensions)) - output[:, 0] = transformed[f"{name}.value"].to_numpy() - index = transformed[f"{name}.component"].to_numpy().astype(int) - output[np.arange(index.size), index + 1] = 1 + if self.continuous_encoder == "bayesian_gmm": + output = np.zeros( + (len(transformed), column_transform_info.output_dimensions) + ) + output[:, 0] = transformed[f"{name}.value"].to_numpy() + index = transformed[f"{name}.component"].to_numpy().astype(int) + output[np.arange(index.size), index + 1] = 1 + else: + output = transformed.to_numpy().reshape(-1, 1) return pd.DataFrame( output, @@ -286,16 +205,16 @@ def _transform_continuous( def _transform_discrete( self, column_transform_info: FeatureInfo, data: pd.Series ) -> pd.DataFrame: - ohe = column_transform_info.transform + encoder = column_transform_info.transform return pd.DataFrame( - ohe.transform(data.to_frame().values), + encoder.transform(data.to_frame().values), columns=column_transform_info.transformed_features, ) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: """Take raw data and output a matrix data.""" - if len(self._column_transform_info_list) == 0: + if len(self._column_transform_info) == 0: return pd.DataFrame(np.zeros((len(raw_data), 0))) column_data_list = [] @@ -305,7 +224,7 @@ def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: data = raw_data[name] column_data_list.append(data) - for column_transform_info in self._column_transform_info_list: + for column_transform_info in self._column_transform_info: name = column_transform_info.name data = raw_data[name] @@ -330,18 +249,23 @@ def _inverse_transform_continuous( column_data: pd.DataFrame, ) -> pd.DataFrame: encoder = column_transform_info.transform - data = pd.DataFrame(column_data.values[:, :2], columns=["value", "component"]) - data.iloc[:, 1] = np.argmax(column_data.values[:, 1:], axis=1) + if 
self.continuous_encoder == "bayesian_gmm": + data = pd.DataFrame( + column_data.values[:, :2], columns=["value", "component"] + ) + data.iloc[:, 1] = np.argmax(column_data.values[:, 1:], axis=1) + else: + data = column_data return encoder.inverse_transform(data) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def _inverse_transform_discrete( self, column_transform_info: FeatureInfo, column_data: pd.DataFrame ) -> pd.DataFrame: - ohe = column_transform_info.transform + encoder = column_transform_info.transform column = column_transform_info.name return pd.DataFrame( - ohe.inverse_transform(column_data), + encoder.inverse_transform(column_data), columns=[column], ) @@ -351,7 +275,7 @@ def inverse_transform(self, data: pd.DataFrame) -> pd.DataFrame: Output uses the same type as input to the transform function. """ - if len(self._column_transform_info_list) == 0: + if len(self._column_transform_info) == 0: return pd.DataFrame(np.zeros((len(data), 0))) st = 0 @@ -367,7 +291,7 @@ def inverse_transform(self, data: pd.DataFrame) -> pd.DataFrame: feature_types.append(self._column_raw_dtypes) recovered_column_data_list.append(local_data) - for column_transform_info in self._column_transform_info_list: + for column_transform_info in self._column_transform_info: dim = column_transform_info.output_dimensions column_data = data.iloc[:, list(range(st, st + dim))] if column_transform_info.feature_type == "continuous": @@ -396,18 +320,18 @@ def layout(self) -> List[Tuple]: - continuous, and with length 1 + number of GMM clusters. - discrete, and with length , the length of the one-hot encoding. 
""" - return self._column_transform_info_list + return self._column_transform_info def n_features(self) -> int: return np.sum( [ column_transform_info.output_dimensions - for column_transform_info in self._column_transform_info_list + for column_transform_info in self._column_transform_info ] ) def get_column_info(self, name: str) -> FeatureInfo: - for column_transform_info in self._column_transform_info_list: + for column_transform_info in self._column_transform_info: if column_transform_info.name == name: return column_transform_info @@ -424,7 +348,7 @@ def activation_layout( - discrete, and with length , the length of the one-hot encoding. """ out = [] - for column_transform_info in self._column_transform_info_list: + for column_transform_info in self._column_transform_info: if column_transform_info.feature_type == "continuous": out.extend( [ @@ -443,6 +367,35 @@ def activation_layout( return out +class BinEncoder(TabularEncoder): + """Binary encoder (for SurvivalGAN). + + Model continuous columns with a BayesianGMM and normalized to a scalar [0, 1] and a vector. + Discrete columns are encoded using a scikit-learn OneHotEncoder. + """ + + def _transform_continuous( + self, column_transform_info: FeatureInfo, data: pd.Series + ) -> pd.Series: + name = data.name + encoder = column_transform_info.transform + transformed = encoder.transform(data) + return transformed[f"{name}.component"].to_numpy().astype(int) + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: + """Take raw data and output a matrix data.""" + output = raw_data.copy() + + for column_transform_info in self._column_transform_info: + name = column_transform_info.name + output[name] = self._transform_continuous( + column_transform_info, raw_data[name] + ) + + return output + + class TimeSeriesTabularEncoder(TransformerMixin, BaseEstimator): """TimeSeries Tabular encoder. 
@@ -456,10 +409,12 @@ def __init__( max_clusters: int = 10, categorical_limit: int = 10, whitelist: list = [], + encoder: str = "minmax", ) -> None: self.max_clusters = max_clusters self.categorical_limit = categorical_limit self.whitelist = whitelist + self.encoder = encoder def fit_temporal( self, @@ -484,9 +439,8 @@ def fit_temporal( self.temporal_encoder.fit(temporal_df) # Temporal horizons - self.observation_times_encoder = MinMaxScaler().fit( - np.asarray(observation_times).reshape(-1, 1) - ) + self.observation_times_encoder = get_encoder(self.encoder) + self.observation_times_encoder.fit(np.asarray(observation_times).reshape(-1, 1)) return self @@ -672,6 +626,7 @@ def __init__( self, max_clusters: int = 10, categorical_limit: int = 10, + continuous_encoder: str = "gmm", ) -> None: """Create a data transformer. @@ -682,6 +637,7 @@ def __init__( self.encoder = BinEncoder( max_clusters=max_clusters, categorical_limit=categorical_limit, + continuous_encoder=continuous_encoder, ) def _prepare( diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 631480fc..855eb4ec 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -194,6 +194,7 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": """ df = X.dataframe() cond = kwargs.pop("cond", None) + self.loss_history = None # note that the TabularEncoder is not used in this plugin, because the # Gaussian multinomial diffusion module needs to know the number of classes @@ -208,18 +209,23 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": raise ValueError( "cond is already given by the labels for classification" ) - _, cond = X.unpack() + df, cond = X.unpack() self._labels, self._cond_dist = np.unique(cond, return_counts=True) self._cond_dist = self._cond_dist / self._cond_dist.sum() - else: + self.target_name = cond.name + self.target_iloc = 
list(X.columns).index(cond.name) + + if cond is not None: if type(cond) is str: cond = df[cond] + self.expecting_conditional = True if cond is not None: cond = pd.Series(cond, index=df.index) # NOTE: cond may also be included in the dataframe self.model.fit(df, cond, **kwargs) + self.loss_history = self.model.loss_history return self @@ -230,8 +236,14 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader # randomly generate labels following the distribution of the training data cond = np.random.choice(self._labels, size=count, p=self._cond_dist) + if cond is not None and len(cond) > count: + raise ValueError("The length of cond is less than the required count") + def callback(count): # type: ignore - return self.model.generate(count, cond=cond) + data = self.model.generate(count, cond=cond) + if self.is_classification: + data = np.insert(data, self.target_iloc, cond, axis=1) + return data return self._safe_generate(callback, count, syn_schema, **kwargs) diff --git a/src/synthcity/utils/dataframe.py b/src/synthcity/utils/dataframe.py index a313b91e..80104e23 100644 --- a/src/synthcity/utils/dataframe.py +++ b/src/synthcity/utils/dataframe.py @@ -1,7 +1,19 @@ +# stdlib +import enum + # third party import pandas as pd +class TaskType(enum.Enum): + BINARY = "binary" + MULTICLASS = "multiclass" + REGRESSION = "regression" + + def __str__(self) -> str: + return self.value + + def constant_columns(dataframe: pd.DataFrame) -> list: """ Find constant value columns in a pandas dataframe. 
diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb new file mode 100644 index 00000000..97e38401 --- /dev/null +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -0,0 +1,1936 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "97e2d93c", + "metadata": {}, + "source": [ + "# Tutorial 8: Modelling tabular data with diffusion models\n", + "\n", + "This tutorial demonstrates how to use a denoising diffusion probabilistic model (DDPM) to synthesize tabular data. The algorithm was proposed in [TabDDPM: Modelling Tabular Data with Diffusion Models](https://arxiv.org/abs/2209.15421)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "696e0157", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[KeOps] Warning : \n", + " The default C++ compiler could not be found on your system.\n", + " You need to either define the CXX environment variable or a symlink to the g++ command.\n", + " For example if g++-8 is the command you can do\n", + " import os\n", + " os.environ['CXX'] = 'g++-8'\n", + " \n", + "[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode\n" + ] + } + ], + "source": [ + "# stdlib\n", + "import sys\n", + "import warnings\n", + "sys.path.insert(0, '../src')\n", + "\n", + "# third party\n", + "import numpy as np\n", + "from sklearn.datasets import load_iris, load_diabetes\n", + "\n", + "# synthcity absolute\n", + "import synthcity.logger as log\n", + "from synthcity.plugins import Plugins\n", + "from synthcity.plugins.core.dataloader import GenericDataLoader\n", + "\n", + "log.add(sink=sys.stderr, level=\"INFO\")\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "54ce9a10", + "metadata": {}, + "source": [ + "## Synthesize a classification dataset\n", + "\n", 
+ "For classification datasets, TabDDPM automatically uses the labels as the conditional variable during training. You should not provide an additional `cond` argument to the `fit` method." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "51076cdc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
05.13.51.40.20
14.93.01.40.20
24.73.21.30.20
34.63.11.50.20
45.03.61.40.20
\n", + "
" + ], + "text/plain": [ + " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", + "0 5.1 3.5 1.4 0.2 \n", + "1 4.9 3.0 1.4 0.2 \n", + "2 4.7 3.2 1.3 0.2 \n", + "3 4.6 3.1 1.5 0.2 \n", + "4 5.0 3.6 1.4 0.2 \n", + "\n", + " target \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 0 " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: preprocessing data with OneHotEncoder or StandardScaler is not needed or recommended. Synthcity handles feature encoding and standardization internally.\n", + "\n", + "X, y = load_iris(return_X_y=True, as_frame=True)\n", + "X[\"target\"] = y\n", + "\n", + "loader = GenericDataLoader(X, target_column=\"target\", sensitive_columns=[])\n", + "\n", + "loader.dataframe().head()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "52397e4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 50\n", + "1 50\n", + "2 50\n", + "Name: target, dtype: int64" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.value_counts()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cda52bea", + "metadata": {}, + "source": [ + "### Model fitting" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "3bf24be4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-03-27T15:19:24.516935+0200][30696][INFO] Step 100: MLoss: 0.0 GLoss: 0.2235 Sum: 0.2235\n", + "[2023-03-27T15:19:25.913968+0200][30696][INFO] Step 200: MLoss: 0.0 GLoss: 0.2298 Sum: 0.2298\n", + "[2023-03-27T15:19:27.191123+0200][30696][INFO] Step 300: MLoss: 0.0 GLoss: 0.2305 Sum: 0.2305\n", + "[2023-03-27T15:19:28.432055+0200][30696][INFO] Step 400: MLoss: 0.0 GLoss: 0.2273 Sum: 0.2273\n", + "[2023-03-27T15:19:29.766838+0200][30696][INFO] Step 500: MLoss: 0.0 GLoss: 0.2333 Sum: 0.2333\n", + 
"[2023-03-27T15:19:31.280538+0200][30696][INFO] Step 600: MLoss: 0.0 GLoss: 0.221 Sum: 0.221\n", + "[2023-03-27T15:19:33.034999+0200][30696][INFO] Step 700: MLoss: 0.0 GLoss: 0.2123 Sum: 0.2123\n", + "[2023-03-27T15:19:34.519078+0200][30696][INFO] Step 800: MLoss: 0.0 GLoss: 0.2212 Sum: 0.2212\n", + "[2023-03-27T15:19:36.020932+0200][30696][INFO] Step 900: MLoss: 0.0 GLoss: 0.2014 Sum: 0.2014\n", + "[2023-03-27T15:19:38.330664+0200][30696][INFO] Step 1000: MLoss: 0.0 GLoss: 0.2069 Sum: 0.2069\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# define the model hyper-parameters\n", + "plugin_params = dict(\n", + " is_classification = True,\n", + " n_iter = 1000, # epochs\n", + " lr = 0.002,\n", + " weight_decay = 1e-4,\n", + " batch_size = 1000,\n", + " model_type = \"mlp\", # or \"resnet\"\n", + " num_timesteps = 500, # timesteps in diffusion\n", + " n_layers_hidden = 3,\n", + " dim_hidden = 256,\n", + " dim_embed = 128,\n", + " dropout = 0.0,\n", + " # performance logging\n", + " log_interval = 10,\n", + " print_interval = 100,\n", + ")\n", + "\n", + "plugin = Plugins().get(\"ddpm\", **plugin_params)\n", + "plugin.fit(loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e1a270c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TabDDPM(\n", + " (diffusion): GaussianMultinomialDiffusion(\n", + " (denoise_fn): MLPDiffusion(\n", + " (emb_nonlin): SiLU()\n", + " (proj): Linear(in_features=4, out_features=128, bias=True)\n", + " (time_emb): TimeStepEmbedding(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): SiLU()\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " )\n", + " (label_emb): Embedding(3, 128)\n", + " (model): MLP(\n", + " (model): Sequential(\n", + " (0): LinearLayer(\n", + " (model): Sequential(\n", + " (0): 
Linear(in_features=128, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (1): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (2): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (3): Linear(in_features=256, out_features=4, bias=True)\n", + " )\n", + " (loss): MSELoss()\n", + " )\n", + " )\n", + " )\n", + " (ema_model): MLPDiffusion(\n", + " (emb_nonlin): SiLU()\n", + " (proj): Linear(in_features=4, out_features=128, bias=True)\n", + " (time_emb): TimeStepEmbedding(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): SiLU()\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " )\n", + " (label_emb): Embedding(3, 128)\n", + " (model): MLP(\n", + " (model): Sequential(\n", + " (0): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=128, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (1): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (2): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (3): Linear(in_features=256, out_features=4, bias=True)\n", + " )\n", + " (loss): MSELoss()\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "49b18ada", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { 
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEGCAYAAAB1iW6ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAAsTAAALEwEAmpwYAAAwsklEQVR4nO3dd3gU5drH8e+9m7IBAtKkI1EBpYOAKE0EpChgP4Ag2FCPiOdYjlhe5WA7iL2CekDFgiCIKGikqFTpRaqEZkKR0AnJkrL3+8cuOUkIyQJJNru5P9eVi52ZZ2fu2Qm/zD6zO4+oKsYYY4KfI9AFGGOMKRgW6MYYEyIs0I0xJkRYoBtjTIiwQDfGmBARFqgNV6pUSevUqROozRtjTFBasWLFflWtnNuygAV6nTp1WL58eaA2b4wxQUlEdp5umXW5GGNMiLBAN8aYEGGBbowxISJgfejGGJOXtLQ0EhIScLvdgS4lIFwuFzVr1iQ8PNzv51igG2OKpYSEBKKjo6lTpw4iEuhyipSqcuDAARISEoiJifH7eX51uYhIdxHZLCJxIjI8l+W1ReRnEVklImtFpOcZ1G6MMadwu91UrFixxIU5gIhQsWLFM353km+gi4gTeBfoATQA+olIgxzNngYmqWpzoC/w3hlVYYwxuSiJYX7S2ey7P2forYE4Vd2mqqnARKBPjjYKlPU9LgfsPuNK/LRyzae8NfVWMtJTC2sTxhgTlPwJ9BpAfJbpBN+8rEYAA0QkAZgJPJjbikRkiIgsF5HliYmJZ1Eu/J4wnw+PbcTtPnhWzzfGmIL08ccfM3To0ECXARTcxxb7AR+rak2gJzBBRE5Zt6p+oKotVbVl5cq5fnM1X66wKABS3IfPvlpjjAlB/gT6LqBWlumavnlZ3QVMAlDVxYALqFQQBebkCi8FgNsC3RhTyHbs2MEll1zC4MGDqVevHrfddhuzZ8+mbdu21K1bl6VLl57S/uqrr6ZJkyZ07tyZP//8E4DJkyfTqFEjmjZtSocOHQBYv349rVu3plmzZjRp0oQtW7acc73+fGxxGVBXRGLwBnlfoH+ONn8CnYGPReRSvIF+dn0q+YgKLwOA+8TRwli9MaYY+vd369mwu2D/zzeoXpZnezXMt11cXByTJ09m3LhxtGrVii+++IIFCxYwffp0XnzxRa6//vrMtg8++CCDBg1i0KBBjBs3jmHDhjFt2jRGjhxJbGwsNWrU4PDhwwCMGTOGhx56iNtuu43U1FQyMjLOeZ/yPUNX1XRgKBALbMT7aZb1IjJSRHr7mj0C3CMia4AvgcFaSIOVuiKiAXCfOFIYqzfGmGxiYmJo3LgxDoeDhg0b0rlzZ0SExo0bs2PHjmxtFy9eTP/+3vPdgQMHsmDBAgDatm3L4MGD+fDDDzOD+4orruDFF19k1KhR7Ny5k6ioqHOu1a8vFqnqTLwXO7POeybL4w1A23Ouxg9REd4z9JTUY0WxOWNMMeDPmXRhiYyMzHzscDgypx0OB+np6X6tY8yYMSxZsoQZM2Zw2WWXsWLFCvr378/ll1/OjBkz6NmzJ2PHjuXqq68+p1qD7l4urgjvpyNTTligG2OKlyuvvJKJEycC8Pnnn9O+fXsAtm7dyuWXX87IkSOpXLky8fHxbNu2jQsvvJBhw4bRp08f1q5de87bD7qv/rsifV0udoZujClm3n77be644w5Gjx5N5cqVGT9+PACPPfYYW7ZsQVXp3LkzTZs2ZdSoUUyYMIHw8HCqVq3Kk08+ec7bl0Lq6s5Xy5Yt9WwGuIiPX0zPuUN4oVYvel/9YiFUZowpDjZu3Mill14a6DICKrfXQERWqGrL3NoHXZdLlOs8ANxpxwNbiDHGFDNBF+guX6CnpKcEthBjjClmgi7QI13lAHBboBtjTDZBF+jh4aUIU7VAN8aYHIIu0AGiFNwZJwJdhjHGFCtBGeguhRQLdGOMySY4Ax3B7UkLdBnGmBJo8ODB
fP3114EuI1dBGehR4iDFYwNcGGNMVkEZ6C6cdoZujCl0zz33HPXr16ddu3b069ePV155JdvyOXPm0Lx5cxo3bsydd97JiRPeruDhw4fToEEDmjRpwqOPPgrkfgvdghZ0X/0HiHI4ceu532rSGBMkfhgOe38v2HVWbQw9/nPaxcuWLWPKlCmsWbOGtLQ0WrRowWWXXZa53O12M3jwYObMmUO9evW4/fbbef/99xk4cCDffPMNmzZtQkQyb5eb2y10C1pwnqFLGCkW6MaYQrRw4UL69OmDy+UiOjqaXr16ZVu+efNmYmJiqFevHgCDBg1i3rx5lCtXDpfLxV133cXUqVMpVco7KE9ut9AtaEF5hu5yhONOt6/+G1Ni5HEmXdyEhYWxdOlS5syZw9dff80777zD3Llzc72FbsWKFQt0236doYtIdxHZLCJxIjI8l+Wvi8hq388fInK4QKvMweWIwI2nMDdhjCnh2rZty3fffYfb7SYpKYnvv/8+2/L69euzY8cO4uLiAJgwYQIdO3YkKSmJI0eO0LNnT15//XXWrFkD5H4L3YKW7xm6iDiBd4GuQAKwTESm+wa1AEBV/5ml/YNA8wKvNAuXMwK3FOYWjDElXatWrejduzdNmjShSpUqNG7cmHLlymUud7lcjB8/nltuuYX09HRatWrFfffdx8GDB+nTpw9utxtV5bXXXgNyv4VuQfOny6U1EKeq2wBEZCLQB9hwmvb9gGcLprzcRTkjcRfmBowxBnj00UcZMWIEycnJdOjQgcsuu4x77rknc3nnzp1ZtWpVtudUq1btlMGjAaZOnVro9foT6DWArO8NEoDLc2soIhcAMcDccy/t9FxOFykOQT0exBGU13WNMUFgyJAhbNiwAbfbzaBBg2jRokWgS8pTQV8U7Qt8rZr7R1BEZAgwBKB27dpnvRFXmAuAEyeO4Ioqf9brMcaYvHzxxReBLuGM+HN6uwuolWW6pm9ebvoCX55uRar6gaq2VNWWlStX9r/KHFxh3tGxU1IOnvU6jDEm1PgT6MuAuiISIyIReEN7es5GInIJUB5YXLAlnioq3Pu5Trf7cGFvyhhjgka+ga6q6cBQIBbYCExS1fUiMlJEemdp2heYqEUwSGlUeBkAUk4cKexNGWNM0PCrD11VZwIzc8x7Jsf0iIIrK2+uCG+gu08cLapNGmNMsReUHxGxQDfGFIUyZcoEuoQzEpSBHhURDYA79ViAKzHGmOIjKAPd5Qv0lLSkAFdijCkJVJXHHnuMRo0a0bhxY7766isA9uzZQ4cOHWjWrBmNGjVi/vz5ZGRkMHjw4My2r7/+epHVGZw354r0fv3WnWo36DKmJBi1dBSbDm4q0HVeUuESHm/9uF9tp06dyurVq1mzZg379++nVatWdOjQgS+++IJu3brx1FNPkZGRQXJyMqtXr2bXrl2sW7cOoNBulZub4DxDd/kCPc0C3RhT+BYsWEC/fv1wOp1UqVKFjh07smzZMlq1asX48eMZMWIEv//+O9HR0Vx44YVs27aNBx98kB9//JGyZcsWWZ1BeYYe5ToPAHd6cmALMcYUCX/PpItahw4dmDdvHjNmzGDw4ME8/PDD3H777axZs4bY2FjGjBnDpEmTGDduXJHUE6Rn6N6v+6ekpwS4EmNMSdC+fXu++uorMjIySExMZN68ebRu3ZqdO3dSpUoV7rnnHu6++25WrlzJ/v378Xg83HTTTTz//POsXLmyyOoMyjP0iIhoRNUC3RhTJG644QYWL15M06ZNERFefvllqlatyieffMLo0aMJDw+nTJkyfPrpp+zatYs77rgDj8c7ZsNLL71UZHUGZaCLw4FLwZ1uN9E1xhSepCTvJ+lEhNGjRzN69OhsywcNGsSgQYNOeV5RnpVnFZRdLgBRgDvjRKDLMMaYYiNoA92lgtuTGugyjDGm2AjaQI9CSMlIC3QZxphCVAT3+iu2zmbfgzbQXeLArRboxoQql8vFgQMHSmSoqyoH
DhzA5XKd0fOC8qIogEvCcHvSA12GMaaQ1KxZk4SEBBITEwNdSkC4XC5q1qx5Rs8J4kB3ctT60I0JWeHh4cTExAS6jKAStF0uUY5wUvAEugxjjCk2/Ap0EekuIptFJE5Ehp+mza0iskFE1otIoY+s6nKE41YLdGOMOSnfLhcRcQLvAl2BBGCZiExX1Q1Z2tQFngDaquohETm/sAo+yeWMIIWSd7HEGGNOx58z9NZAnKpuU9VUYCLQJ0ebe4B3VfUQgKruK9gyT+VyRuKWwt6KMcYED38CvQYQn2U6wTcvq3pAPRFZKCK/iUj33FYkIkNEZLmILD/XK9dRvkBXj3W7GGMMFNxF0TCgLnAV0A/4UETOy9lIVT9Q1Zaq2rJy5crntEFXmIsMEdLT7AZdxhgD/gX6LqBWlumavnlZJQDTVTVNVbcDf+AN+ELjCosCIMV9sDA3Y4wxQcOfQF8G1BWRGBGJAPoC03O0mYb37BwRqYS3C2ZbwZV5KldYKQDc7iOFuRljjAka+Qa6qqYDQ4FYYCMwSVXXi8hIEentaxYLHBCRDcDPwGOqeqCwigaICi8NgPvE4cLcjDHGBA2/vimqqjOBmTnmPZPlsQIP+36KRFR4GQBSTtgZujHGQBB/U9QV4Q1094ljAa7EGGOKhyAO9GgA3KlHA1yJMcYUD0Eb6FGRJwM9KcCVGGNM8RC0ge6KKAtASqp1uRhjDARzoLvKAZCSdjzAlRhjTPEQvIEeeR4A7rTkwBZijDHFRNAGepTrPADc6fbVf2OMgSAO9EgLdGOMySZoA90ZFkGEKikZ7kCXYowxxULQBjqAS8GdfiLQZRhjTLEQ/IHusUA3xhgI8kAvheDOSAt0GcYYUywEdaC7xEGKxwLdGGMg6APdSYpaoBtjDAR9oIfh9mQEugxjjCkWgj/QsUA3xhjwM9BFpLuIbBaROBEZnsvywSKSKCKrfT93F3ypp4pyhONWT1Fsyhhjir18RywSESfwLtAV72DQy0RkuqpuyNH0K1UdWgg1npbLGU4KWpSbNMaYYsufM/TWQJyqblPVVGAi0Kdwy/KPyxGJWwJdhTHGFA/+BHoNID7LdIJvXk43ichaEflaRGrltiIRGSIiy0VkeWJi4lmUm50rzALdGGNOKqiLot8BdVS1CTAL+CS3Rqr6gaq2VNWWlStXPueNRjldpIqQkZ56zusyxphg50+g7wKynnHX9M3LpKoHVPXkd/A/Ai4rmPLy5gqLAuCE+3BRbM4YY4o1fwJ9GVBXRGJEJALoC0zP2kBEqmWZ7A1sLLgST+9koKdYoBtjTP6BrqrpwFAgFm9QT1LV9SIyUkR6+5oNE5H1IrIGGAYMLqyCs4p2lQfgyLH4fFoaY0zoy/djiwCqOhOYmWPeM1kePwE8UbCl5a9mxUtgJyTsW8uFMZ2LevPGGFOsBPU3RWtW83bVxx/8I8CVGGNM4AV1oFesUI8oj5JwLCHQpRhjTMAFdaCLw0EtnMS79we6FGOMCbigDnSAWmFliE8/HugyjDEm4II+0Gu6KpPg8ODJSA90KcYYE1BBH+i1ytYmVYR9iesCXYoxxgRU8Ad6hfoAJOxdHdhCjDEmwII/0Ks2AyD+QM67+RpjTMkS9IFetWpznKrEH9kR6FKMMSaggj7Qw8NLUdUjJCT/FehSjDEmoII+0AFqOaOITzsa6DKMMSagQiPQIyuQQFqgyzDGmIAKjUAvU4PDDuHY0V35NzbGmBAVGoF+3sUAxO9eFuBKjDEmcEIj0M9vDEB84u8BrsQYYwInJAK9ZvWWAMQf3hrgSowxJnD8CnQR6S4im0UkTkSG59HuJhFREWlZcCXmr3SZqlTwKAlJu4tys8YYU6zkG+gi4gTeBXoADYB+ItIgl3bRwEPAkoIu0h81iSAh9VAgNm2MMcWCP2forYE4Vd2mqqnARKBPLu2eA0YB7gKsz2+1IsoRnxGQTRtjTLHgT6DXALKOwpzgm5dJRFoAtVR1Rl4rEpEh
IrJcRJYnJiaecbF5qVWqKnsdStoJuze6MaZkOueLoiLiAF4DHsmvrap+oKotVbVl5cqVz3XT2dQsdwEeET6YeTdHj8Tn/wRjjAkx/gT6LqBWlumavnknRQONgF9EZAfQBphe1BdGO7d6iLaUYszRdXSd2oPXvr6BtLTkoizBGGMCyp9AXwbUFZEYEYkA+gLTTy5U1SOqWklV66hqHeA3oLeqLi+Uik+jTHQ1xgxawqQ2z9M+vALjj8fx44IXi7IEY4wJqHwDXVXTgaFALLARmKSq60VkpIj0LuwCz9Sl9fvwcr+5VMxQ5u+aH+hyjDGmyIT500hVZwIzc8x75jRtrzr3ss6NwxlGO1dV5p7YS3qam7BwV6BLMsaYQhcS3xTNTftaHTnmENZumBToUowxpkiEbKBf2fROwlSZFzc9/8bGGBMCQjbQo8vWoDku5h3dEuhSjDGmSIRsoAN0qNSMLQ4Pe3avCHQpxhhT6EI70Bv2A2D+758GuBJjjCl8IR3oMRd0okYGzPtraaBLMcaYQhfSgS4OBx1K12ZJxjHcKXYnRmNMaAvpQAdoH9MNt0NY9vuEQJdijDGFKuQDvXWTQbg8yvztsYEuxRhjClXIB3qkqxyXO6OZd/xP1OMJdDnGGFNoQj7QATpUac0uJ2zf+UugSzHGmEJTIgK9fePbAZi/4csAV2KMMYWnRAR6teqXcbHHwfz9qwNdijHGFJoSEegA7ctexApNIenYnkCXYowxhaLEBHqHi3qTLsJvaz4OdCnGGFMo/Ap0EekuIptFJE5Ehuey/D4R+V1EVovIAhFpUPClnptmDfsS7VHmxc8JdCnGGFMo8g10EXEC7wI9gAZAv1wC+wtVbayqzYCX8Q4aXayEhbu4MrwC81P22scXjTEhyZ8z9NZAnKpuU9VUYCLQJ2sDVT2aZbI0oAVXYsHpUL0d+53Chj+mBboUY4wpcP4Eeg0gPst0gm9eNiLygIhsxXuGPqxgyitY7ZrdicujPLVoBHv3rg50OcYYU6AK7KKoqr6rqhcBjwNP59ZGRIaIyHIRWZ6YmFhQm/ZbhQoX816zf/KXeBg4cwDbtlt/ujEmdIhq3r0jInIFMEJVu/mmnwBQ1ZdO094BHFLVcnmtt2XLlrp8+fKzKvpcbdo8nfsWPkm6QMuwciSmp3BQ07i71jXc1PXVgNRkjDH+EJEVqtoyt2X+nKEvA+qKSIyIRAB9gWwDdYpI3SyT1wLFety3S+r3ZkLXD4mRCHakHSPK4USBsfGxZKSnBro8Y4w5K2H5NVDVdBEZCsQCTmCcqq4XkZHAclWdDgwVkS5AGnAIGFSYRReEWrWuYMLglZnTsxa8yMNbv2T+8re5qs0jAazMGGPOTr5dLoUlkF0uuUlLS6bbhNbUd5bh/UG/BbocY4zJ1bl2uZQI4eGluKl8YxZqEgkJFujGmOBjgZ7FTVc8jgCTl4wOdCnGGHPGLNCzqFq1GR0d5fjm6GZSTxwLdDnGGHNGLNBz+Nul/TjkEGYttrN0Y0xwsUDP4YoW91EjA6bt/LHA152cvL/A12mMMSdZoOfgcIbRo1x9lmkyhw5uzbZs5q/P8smMIWd1c69v5w7n8smd6D++OZ/MuIe9e1YVVMnGGANYoOfqmkYDyRBhzop3M+e5Uw7xwrYpvLJ/Me9+2/+M1nfo4FZG7/yeizOEdJRX9v9Gt9iBvDX1FtLSkgu6fGNMCWWBnotL6vaiVgb8tHtB5rzYxaM46hBaaARjj65n/Pd3+b2+N2Lv47jA6I6vMOmO1czo9D69I6ry4bFNDPzsSnbs+PWM6tu4+VsOHozLs4075VC+ozPt37+JGb88Y7cTNiZEWKDnQhwOrilXn6VZul0m7fyJmAzhv7ctpLuzPK8dWMrH399Nepo7z3WtWvsZU1P3MjC6HhdfdA0AtWu347n+s3nton7Ek86tPz/AouXv+VXbstXj6L/4Kf71Xb882z0yuSddv+7KjF+eOW2bd2c/xPCd
3zB/2Zt+bbukUI+HH34dwcJl7wS6FGPOiAX6aXTzdbvMXfEemzZPZ60jjVurXUlYuIsX//YjnaQsrx5YQu8Jrfhm9mO5dp2kp7l5bsVoqmYo93Ubc8ryru2eZGrPL6iFk3/8/h5r13+VZ00JCb/x8KrXcABLcLNp8/Rc263bMJl5mkQEMHznNzz++VUcO7orW5vkpH3MTPHOe3X9+Hz/MAW7o0fi/bpPT3z8Qu6ZcDn/2jGFx9eNwZ1yqAiqM6ZgWKCfxv+6XeYzadV7RHqUXlc+CXi/VfrmgPm8VW8Q0eLkmV0/0n1Ca96ccjPbd/xCcvJ+vp71MP0+u5wtDg/DLxlIqTLn57qdKlWaMPa6iVRU4e9Ln2Pr1lm5tktO2sewWfeSAXzS+llKeZRPlr+ea9uPVrxBtEeZfuNMHjivKbFp+xn4dY9sn62P/W00yQ5hcOmL2eZUpsz917m9YH7Yv38Thw9tP2X+r7+9xnvf9Mu3Gyk36vHw119r87wWER+/mC5Te3DTpy2ZveClXLuYTriP8NF3g7lx9r2s96TQP+oCjjiEHxblelNRY4olu5dLHt6YcjMfH9tEhMI1kVV5vv/sU9qox8P8ZW8zcfOXLPIkkSFCpEc54RDqehwMrN2N6zv9B3Hk/bczPn4hA2fdixNoH1Wd4xlukjJScYoQIWHsTj/GRknn/Yb3c2WrBxg1uTcTj2/jh+6fUbVqs8z1xG39iRsWPMJ9ZRvxwA1fAjB30cs8tGUCj1S8nMHXfQTAgPEtOKbpTBu8mjs/bc1WdTPjpliioiryxaxhTP9rCc2iqtPtklto0WgAzrAIv16zXbuWMnbek1SKrMDtnUZxXvkYMtJT+XLWMN7au4AyCh90fC2z++nX317jH5vGkS5ClEfpG12Xbo0GsitxA9sObiJMnNx13fhcX7+tW2cxesHTLCSZCFXqahiNSlXnvs6vUanSJZnt/jmhHQvTD1NVHWx3Kg09TvpUa0fj2h25OKYLs357hbe3f8sep9BJyvJUt/c5v3Ijbvy4GeHi4KtBK3Pdvno8uN2HiSpVwa/XJhQlJ+3DFVUBhzPf+/yZApLXvVws0POwYdM0/rbk/wD4vOXTNGn4tzzb70/cyIylrxGftItrL+1Ps0b98w3yrDb/8T0PL3iCZJQyOCgtDjxAqnpIRxlUswu3XOM9K9+1ayk9Z93JoDJ1efjmbzLX8cTnnZiTmshP13/HeeVjMuff/0kbVnuS+L7XVA4d2cENCx7h0UpXMOjaD1i/aSp9lzxLD2cFtqcdZpPDwyUeBzvIwO0Qzs9QXm31BM0a33ba2tPT3HwWO5T3En/z1iwQpfC36LqsStrJakmjLaX4I+M4qQLvt36WlNQj3L/mDepqGP/X5v/4bOU7zExLxCOSbd3PVu/KzV3/N0xtcvJ+3vhuEJNSdlJKYWD5JqSkp7DheAKrNIVLNYJx/X8lIjKapas+4q61b/Jg+Wbc2eNDZswfwZgdM0hwetclqqgIl3qcPNL071zeYkjmdr6KHcbze3/ms8ueommjvtlqSk7axyNTerPOk8Skbp9Qrfplub4uf8T9wHuLn8eBg7JhUVSJqsztXV6ldJmqef4unAlPRjoHDmwmIyOV9IwTVKxQr0j+yBw+tJ0+03pxgUTyRu+vqFDh4kLf5klJx/bw0U8P8rcrnjjtax+qLNDPkno89Pq4KaXEedqztEB69LMOLEo7yKxbZlO6TFXi4xfTa849DCh9MY/eMi1b223b53Ljr8O40VUTlzOSL49vZU7vaZn/CZ/8/Gq+S0/k/AxleL3+dLlyOCnug8xb/h5vx03mkCjj2/6H+vWuA2DvnlW8+/Nj7Ek7QoonnX2axl6ncJVE82SXtzmecoCxi58nNv0gZRUer3M913UcScKuJdwzawiHxPt7V10djL/h28w/Pjt2/MqmP3+h9vlNqVPzSoZNuZZ1nhSmdZ9A1WrNSU9z8+AXV7FIk7jFVYsHrnmH8hUu
ytzP2PnP8ei2SVwfXoURt87k1gmtSFIP3/abhyuqPOA9rnv3rmLdtlg27ltF3YoN6dbu6VPOMo8n7aXz5C50iqjMS7f9nDn/wP4/eOC7W9ko6UQoNHOUYuyARac8f92Gydy75N8IUFEdHMXDfqdwc2R1nu0bm9kuIz2VibP+QZPanWjc8JbM+erxsHzNeEScNKzXO9eQTjtxnPsnXsUS/ncNpGKGMq7Tm1wY0znvX6A8qMfD4pVjaNbgVkqVqpRrm1GTevFF8nbCFSqp8G7H17gwpjMb/5jOj+s+ocH5zene4fQX5desm8iiuO/o1vTOM671nW/6Mvboei7MED696XvKlat9Rs8PZhbo5yAh4TfCwlzZujWKi9/XT6b/8pFcQRRlnS7iUg/zp8PDjz0mcn6VRqe0/8+kXnyZvJ0ohbbh5Xl1wPzMZUcO7+CH317lurZPUia6Wrbn7dm9goE/DiId+KTz+6zbPosXtk8lHagvEZSWcEo5IugZ05POVz6e7Q/f3r2rKRVVkbLlamXO2/fXOu6deRup6mH8tV/mWutJ8fGLuWn2PbR0luHdAYv491fdmZK6h2eq/e/dSk4n/7NfjosluHnlwlvp1v7//HxVs3vxq2v5OmUns66bSoUKF7Nh8zc8tngEiaKMvvQOEo8mMHLPbJ6qchV9u7+d+byVaz7l7ytf5jwVPrrmQ2rWbAPA6Ml9+DR5G580e4wWTW/PVi/ANc7zGNruOf6In8+HcV+z2eHt7w9Tpb6GcVONTtzc5VXE4UA9nszX496yDakeXRNV5e34WBzAx13GULt2u7Pa75N/GC/1OHmrx8en/P7Hxy+k95x76RNZnRub3MWwJc9xQqCKOtjq9GZKpEf5pstYatVqe8r6Dx3cyg3T+nDA6X031sDjpFfVNlzX5vFs7yxTTxwjcf9GatRonTnvyOEddPvmOi4gjC2SThNcfNDvZyIio/Pcp9QTxwgPL11gJ2Z//bWW8b8+xdrjCbzd6ysqVqpXIOvNjwV6CHt4QntWpB+irAplJYzra1x12qA7cngH135zHUccwtgG93Flqwf83s627XMZ/Msw3ECKQ2jqCeelLm/n+p/VH+lpbjyetHz/EwJ8/sP9/GffAq4gisWkcE/0pQy7cdJp23sy0vnnFx2Z6znKZRrJ+NuXnvV/4m3b59Jn3kO0l9LszUhhi8PDeR7lnVZP07RRX9Tj4f4JV7DSc5wpXcbidEYybcmrfHz4d6qogw97fpotDJOT9nHDpM64ECbftpAFK97noS0T6B1+PjWizufjw7+T4vCGXJ0M4a4LelC+dBVW71rEoqNxbHBk0MNZgWev/4pvFzzHS3/N457oSxh24+TMbcRt/Yk75z1MpMLH3cZlC8OcUpIPsmr9l7RpcW/mOwxPRjo3f3oZRzWDJIHSCm9dOZKGl9yY+bxHPmvP/LRDzLj2Kyqf35A9u1fw9E/3koZyXbV2NL+oB7fPf5RGjlJ8MPC3bK+/ejz88/P2zMs4wnuNH2TLXyv57q8lbHRkEKFK57CKtDy/OUv2rWBh2iGOO4SX69xEj44jAHhr6i18dHQjU9q9wpZdi3l8x1S6Ocvz3I3TTtvVNHvBSzy75XNaOMvyWr/ZhIeXyly2Zt1EPJ50mjcZ4NfvxNats/hy+etMTfkTBRToFVGV53K5xlYYzjnQRaQ78CbeEYs+UtX/5Fj+MHA3kA4kAneq6s681mmBHhgzf32WH3fO5o3bfj3jC1kbNk3jiUXPcF2lZtzR4wPCwl2FVGV2nox0Bk+4nFWSSu/wyjzfd3a+AX08aS/v/nAvf2v9KBdc0P6ctn/vJ61ZRAqNPWH0rtaOHm0eodx5dTKX7927mht/GEAYcNjX/d/OUYaRPT/OdnH2pHlL3uSBTR9xY0RVYt17qEMYn/T7lUhXOfbv38SUhS9wQfm6dL1yeLaL0Z6MdMbNuIu3D66gukfY41DaO8ry5m3zTjmWmzZP585FT1JG4e32
o6hf99pT6jh2dBdDp/ZmpaTyaKU2DLr2QwDmLPwP/4j7nJcuuJ56Na9k6LzHOCQwsNyl3NDqYQ4c2c7AFS9xf9lG/N134T03J69BvFi7N706vZA5/9u5w3k6fgYPV2zNHdf9N3P+5i0zmLLyPb5P3skxh1ApQ+kYVZ049342SirjWj5N7eqt6P5tH9qHV+CVAfMAGPfdnbx+cBmiSg2PcHFYNE3L16NVTHcuuqAjr824g8kndlErA+KdcH14FUb2/QlxOJg253FGxM8gQ4SeYRV57Jr3qVT50mz7kZy0j7idP7N060xmHljNFoeHMFWud9Xgrnb/ZtKSlxmftCXbtZbdu5cz5tfheBSiw6IoE16G6mVrUadyE2pXb02FChef9UnGOQW6iDiBP4CuQALeMUb7qeqGLG06AUtUNVlE7geuUtU8ryBaoJszsXfPKmJXvkf/Lm8QHlm6SLd99Eg8R44mUKvWFadt89P85xkTN5mu5RvSp/XDVK+e6/+3TI991oEfMw5R3qN8lcdF1dwsXfUR/1r1BuVxMOGW2FO6yE7asGkaDy56mmMCL9TtT9d2T2YuO3gwjvu+vZktkk49DeMPSeezNiNpUO96bv2kOSnqYdrAZYSFu9i/fxMjZ97Fr54jeEQo61EiFGbcOve0H8cF7x+gQRNas0NTmd7nW8qWrcW2HXMZOP9RLhEX/x2wKNdPT7lTDrFr93Ji6nTC4Qzj0MGt9J92PcmitIs8n+9S9/FN+9e56KKugPeMf9GK91i7azFbk+L5I/Uw233dPicvet9Rpi4PXjeBD2fezftH13FP9CW4wly8fWg1bYiiadkLGXdkHZEKXSKrcjzjBEcyUtjtOZF5AR2gmYbTvUoburV8MDP4jyftpfekLlSSML4YuJRdu5dy1+z7OCJQTiFJ4LiAZrnY//j5bRnQ49TvpvjjXAP9CmCEqnbzTT8BoKq5fkBXRJoD76hqnu/FLdBNSbY/cSPP/HAHdza5j5bNBp/x85OT9yM48v00S+K+9fxjxkDvF+Mia1ClVGU8qsw8sJbd4uH1hkNoXLcXN33TiyiEBy66iX/tmMLzta6lz9XZ3oizd+9qpi15hdgDa7n34pvo3uHZfOuM2/oTt8x/mCiFFIF0EUp7lCnX5N0VlNO27XMY8MtDHHMIPZwVeHlA3rfLOHgwjhUbJvH73mW0vbBH5qeX1ONh5KQefH1iNwDXhlXiuZu/JzyyNDt2/MqoXx9ngyeJcuqgnCOM88NKUzf6AupWbkzDC7ud9lrazF+f5fEdUxlc+mJmHN1CusDYK1/g0vp9AO8Ql7t3r2Dn3pX8eXAjrev2od7FPfze/6zONdBvBrqr6t2+6YHA5ao69DTt3wH2qurzuSwbAgwBqF279mU7d+bZK2OMKQCpJ47xwtQbmZq6N3NeeY/yWvNHM/+YLFv1X+5a8zoOoKpH+G7gkmz9zOfimzn/YuneZVR1VaRqmeq0qnc9F8Zcfcbr+W3FWF5f+z6jOr1JnTodz7qe9DQ3/5lyAxVd5bm316cF8hl69Xi489PWLJcTVMpQPur4v3cQBa3IAl1EBgBDgY6qeiKv9doZujFFK+3EcXAIDgnD4Qg7pQ/3ram38uGxjYyofg03dX01QFUGr+07fuHdBc/wYPsXzvm6TV7yCnR//jTtAmplma7pm5dzI12Ap/AjzI0xRS+/aw8P9P6M9hsm0azRmd0e2njF1LmKV+rMC2gN/lxmXQbUFZEYEYkA+gLZ7grl6zcfC/RW1X0FX6YxprA5wyJo3mRAsfsCnfFfvkdOVdPxdqPEAhuBSaq6XkRGikhvX7PRQBlgsoisFpHcbwNojDGm0Ph1NUBVZwIzc8x7JsvjLgVclzHGmDNk762MMSZEWKAbY0yIsEA3xpgQYYFujDEhwgLdGGNChAW6McaECAt0Y4wJERboxhgTIizQjTEmRFigG2NMiLBAN8aYEGGBbowxIcIC3RhjQoQFujHGhAgLdGOM
CREW6MYYEyL8CnQR6S4im0UkTkSG57K8g4isFJF036DSxhhjili+gS4iTuBdoAfQAOgnIg1yNPsTGAx8UdAFGmOM8Y8/Q9C1BuJUdRuAiEwE+gAbTjZQ1R2+ZZ5CqNEYY4wf/OlyqQHEZ5lO8M07YyIyRESWi8jyxMTEs1mFMcaY0yjSi6Kq+oGqtlTVlpUrVy7KTRtjTMjzJ9B3AbWyTNf0zTPGGFOM+BPoy4C6IhIjIhFAX2B64ZZljDHmTOUb6KqaDgwFYoGNwCRVXS8iI0WkN4CItBKRBOAWYKyIrC/Moo0xxpzKn0+5oKozgZk55j2T5fEyvF0xxhhjAsS+KWqMMSHCAt0YY0KEBboxxoQIC3RjjAkRFujGGBMiLNCNMSZEWKAbY0yIsEA3xpgQYYFujDEhwgLdGGNChAW6McaECAt0Y4wJERboxhgTIizQjTEmRFigG2NMiLBAN8aYEOFXoItIdxHZLCJxIjI8l+WRIvKVb/kSEalT4JUaY4zJU76BLiJO4F2gB9AA6CciDXI0uws4pKoXA68Dowq6UGOMMXnzZwi61kCcqm4DEJGJQB9gQ5Y2fYARvsdfA++IiKiqFmCtAPz7u/Vs2H20oFdrjDFFpkH1sjzbq2GBr9efLpcaQHyW6QTfvFzb+AaVPgJUzLkiERkiIstFZHliYuLZVWyMMSZXfg0SXVBU9QPgA4CWLVue1dl7YfxVM8aYUODPGfouoFaW6Zq+ebm2EZEwoBxwoCAKNMYY4x9/An0ZUFdEYkQkAugLTM/RZjowyPf4ZmBuYfSfG2OMOb18u1xUNV1EhgKxgBMYp6rrRWQksFxVpwP/BSaISBxwEG/oG2OMKUJ+9aGr6kxgZo55z2R57AZuKdjSjDHGnAn7pqgxxoQIC3RjjAkRFujGGBMiLNCNMSZESKA+XSgiicDOM3hKJWB/IZVTnJXE/S6J+wwlc79L4j7Due33BapaObcFAQv0MyUiy1W1ZaDrKGolcb9L4j5DydzvkrjPUHj7bV0uxhgTIizQjTEmRARToH8Q6AICpCTud0ncZyiZ+10S9xkKab+Dpg/dGGNM3oLpDN0YY0weLNCNMSZEBEWg5zdIdbASkVoi8rOIbBCR9SLykG9+BRGZJSJbfP+W980XEXnL9zqsFZEWgd2DsyciThFZJSLf+6ZjfAOMx/kGHI/wzQ+ZAchF5DwR+VpENonIRhG5ItSPtYj80/e7vU5EvhQRVygeaxEZJyL7RGRdlnlnfGxFZJCv/RYRGZTbtvJS7APdz0Gqg1U68IiqNgDaAA/49m04MEdV6wJzfNPgfQ3q+n6GAO8XfckF5iFgY5bpUcDrvoHGD+EdeBxCawDyN4EfVfUSoCne/Q/ZYy0iNYBhQEtVbYT39tt9Cc1j/THQPce8Mzq2IlIBeBa4HO9Yzs+e/CPgN1Ut1j/AFUBslukngCcCXVch7eu3QFdgM1DNN68asNn3eCzQL0v7zHbB9IN31Ks5wNXA94Dg/dZcWM5jjvc+/Ff4Hof52kmg9+Es9rkcsD1n7aF8rPnfWMMVfMfue6BbqB5roA6w7myPLdAPGJtlfrZ2/vwU+zN0/BukOuj53l42B5YAVVR1j2/RXqCK73GovBZvAP8CPL7pisBh9Q4wDtn3y68ByINADJAIjPd1NX0kIqUJ4WOtqruAV4A/gT14j90KQv9Yn3Smx/acj3kwBHrIE5EywBTgH6p6NOsy9f6pDpnPlorIdcA+VV0R6FqKWBjQAnhfVZsDx/nfW3AgJI91eaAP3j9m1YHSnNotUSIU1bENhkD3Z5DqoCUi4XjD/HNVneqb/ZeIVPMtrwbs880PhdeiLdBbRHYAE/F2u7wJnOcbYByy71eoDECeACSo6hLf9Nd4Az6Uj3UXYLuqJqpqGjAV7/EP9WN90pke23M+5sEQ6P4MUh2URETwjse6UVVfy7Io66Dbg/D2rZ+cf7vvKnkb4EiWt3RBQVWfUNWaqloH77Gcq6q3
AT/jHWAcTt3noB+AXFX3AvEiUt83qzOwgRA+1ni7WtqISCnf7/rJfQ7pY53FmR7bWOAaESnve3dzjW+e/wJ9IcHPiw09gT+ArcBTga6nAPerHd63YWuB1b6fnnj7DecAW4DZQAVfe8H7iZ+twO94Pz0Q8P04h/2/Cvje9/hCYCkQB0wGIn3zXb7pON/yCwNd9znsbzNgue94TwPKh/qxBv4NbALWAROAyFA81sCXeK8TpOF9N3bX2Rxb4E7f/scBd5xpHfbVf2OMCRHB0OVijDHGDxboxhgTIizQjTEmRFigG2NMiLBAN8aYEGGBbko0EfmHiJQKdB3GFAT72KIp0XzfWG2pqvsDXYsx58rO0E2JISKlRWSGiKzx3Z/7Wbz3GPlZRH72tblGRBaLyEoRmey7zw4iskNEXhaR30VkqYhcHMh9MSY3FuimJOkO7FbVpuq9P/cbwG6gk6p2EpFKwNNAF1VtgfdbnQ9nef4RVW0MvON7rjHFigW6KUl+B7qKyCgRaa+qR3Isb4N3EJWFIrIa7/03Lsiy/Mss/15R2MUac6bC8m9iTGhQ1T98w331BJ4XkTk5mggwS1X7nW4Vp3lsTLFgZ+imxBCR6kCyqn4GjMZ7+9pjQLSvyW9A25P9474+93pZVvG3LP8uLpqqjfGfnaGbkqQxMFpEPHjvinc/3q6TH0Vkt68ffTDwpYhE+p7zNN47fQKUF5G1wAm8w4UZU6zYxxaN8YN9vNEEA+tyMcaYEGFn6MYYEyLsDN0YY0KEBboxxoQIC3RjjAkRFujGGBMiLNCNMSZE/D84fzE31QmjlAAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# plot training curves\n", + "plugin.loss_history.plot()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cf5241cc", + "metadata": {}, + "source": [ + "### Data generation\n", + "\n", + "Since the model training is conditional to the labels, the data generation requires the labels as well. You can pass the labels as a `cond` argument to the `generate` method. If it is not provided, the model will randomly generate the labels following the multinomial distribution of the training labels." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a2e81779", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
06.4421192.9347334.3269331.3725701
16.2854122.7214405.1209012.0575472
24.6963502.0427262.8569090.7889351
35.3360192.6885334.1632831.1920511
46.0818253.2216824.6457681.5052931
55.6901652.3360884.1056301.2966071
65.3989352.7577133.8099841.1613691
77.3582703.2834286.4965902.3172382
86.5953272.5985265.8056531.4513532
95.2247182.7962243.5009151.1252481
\n", + "
" + ], + "text/plain": [ + " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", + "0 6.442119 2.934733 4.326933 1.372570 \n", + "1 6.285412 2.721440 5.120901 2.057547 \n", + "2 4.696350 2.042726 2.856909 0.788935 \n", + "3 5.336019 2.688533 4.163283 1.192051 \n", + "4 6.081825 3.221682 4.645768 1.505293 \n", + "5 5.690165 2.336088 4.105630 1.296607 \n", + "6 5.398935 2.757713 3.809984 1.161369 \n", + "7 7.358270 3.283428 6.496590 2.317238 \n", + "8 6.595327 2.598526 5.805653 1.451353 \n", + "9 5.224718 2.796224 3.500915 1.125248 \n", + "\n", + " target \n", + "0 1 \n", + "1 2 \n", + "2 1 \n", + "3 1 \n", + "4 1 \n", + "5 1 \n", + "6 1 \n", + "7 2 \n", + "8 2 \n", + "9 1 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.generate(10)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f2d6c6cb", + "metadata": {}, + "source": [ + "### Conditional data generation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1f55ffdb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
05.2009353.4104481.2944040.2501560
14.8921723.4047651.3739660.3176620
24.5464153.0013621.3792670.1460120
36.9123333.3724784.7320091.6384991
45.4792602.6232463.4961611.2651181
55.6916102.5684203.6208421.0259881
66.9353143.2469516.2097022.2368082
77.0824953.0612085.9071951.9507212
86.0660102.5531235.1930901.6390342
\n", + "
" + ], + "text/plain": [ + " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", + "0 5.200935 3.410448 1.294404 0.250156 \n", + "1 4.892172 3.404765 1.373966 0.317662 \n", + "2 4.546415 3.001362 1.379267 0.146012 \n", + "3 6.912333 3.372478 4.732009 1.638499 \n", + "4 5.479260 2.623246 3.496161 1.265118 \n", + "5 5.691610 2.568420 3.620842 1.025988 \n", + "6 6.935314 3.246951 6.209702 2.236808 \n", + "7 7.082495 3.061208 5.907195 1.950721 \n", + "8 6.066010 2.553123 5.193090 1.639034 \n", + "\n", + " target \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 1 \n", + "4 1 \n", + "5 1 \n", + "6 2 \n", + "7 2 \n", + "8 2 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])\n", + "plugin.generate(len(labels), cond=labels)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "adf672a5", + "metadata": {}, + "source": [ + "## Synthesize a regression dataset\n", + "\n", + "For regression datasets, there is no conditional variable by default. The model learns the joint distribution of the whole dataset and generates new data points from it." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "13df0848", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
count4898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.000000
mean6.8547880.2782410.3341926.3914150.04577235.308085138.3606570.9940273.1882670.48984710.5142675.877909
std0.8438680.1007950.1210205.0720580.02184817.00713742.4980650.0029910.1510010.1141261.2306210.885639
min3.8000000.0800000.0000000.6000000.0090002.0000009.0000000.9871102.7200000.2200008.0000003.000000
25%6.3000000.2100000.2700001.7000000.03600023.000000108.0000000.9917233.0900000.4100009.5000005.000000
50%6.8000000.2600000.3200005.2000000.04300034.000000134.0000000.9937403.1800000.47000010.4000006.000000
75%7.3000000.3200000.3900009.9000000.05000046.000000167.0000000.9961003.2800000.55000011.4000006.000000
max14.2000001.1000001.66000065.8000000.346000289.000000440.0000001.0389803.8200001.08000014.2000009.000000
\n", + "
" + ], + "text/plain": [ + " fixed acidity volatile acidity citric acid residual sugar \\\n", + "count 4898.000000 4898.000000 4898.000000 4898.000000 \n", + "mean 6.854788 0.278241 0.334192 6.391415 \n", + "std 0.843868 0.100795 0.121020 5.072058 \n", + "min 3.800000 0.080000 0.000000 0.600000 \n", + "25% 6.300000 0.210000 0.270000 1.700000 \n", + "50% 6.800000 0.260000 0.320000 5.200000 \n", + "75% 7.300000 0.320000 0.390000 9.900000 \n", + "max 14.200000 1.100000 1.660000 65.800000 \n", + "\n", + " chlorides free sulfur dioxide total sulfur dioxide density \\\n", + "count 4898.000000 4898.000000 4898.000000 4898.000000 \n", + "mean 0.045772 35.308085 138.360657 0.994027 \n", + "std 0.021848 17.007137 42.498065 0.002991 \n", + "min 0.009000 2.000000 9.000000 0.987110 \n", + "25% 0.036000 23.000000 108.000000 0.991723 \n", + "50% 0.043000 34.000000 134.000000 0.993740 \n", + "75% 0.050000 46.000000 167.000000 0.996100 \n", + "max 0.346000 289.000000 440.000000 1.038980 \n", + "\n", + " pH sulphates alcohol quality \n", + "count 4898.000000 4898.000000 4898.000000 4898.000000 \n", + "mean 3.188267 0.489847 10.514267 5.877909 \n", + "std 0.151001 0.114126 1.230621 0.885639 \n", + "min 2.720000 0.220000 8.000000 3.000000 \n", + "25% 3.090000 0.410000 9.500000 5.000000 \n", + "50% 3.180000 0.470000 10.400000 6.000000 \n", + "75% 3.280000 0.550000 11.400000 6.000000 \n", + "max 3.820000 1.080000 14.200000 9.000000 " + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv(\"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv\", sep=\";\")\n", + "\n", + "loader = GenericDataLoader(df, target_column=\"quality\", sensitive_columns=[])\n", + "loader.dataframe().describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "14bca1cd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "[2023-03-27T18:08:18.761007+0200][38480][INFO] Step 100: MLoss: 1.2836 GLoss: 0.9867 Sum: 2.2703\n", + "[2023-03-27T18:08:24.679745+0200][38480][INFO] Step 200: MLoss: 1.2622 GLoss: 0.9409 Sum: 2.2031\n", + "[2023-03-27T18:08:30.391531+0200][38480][INFO] Step 300: MLoss: 1.2059 GLoss: 0.7669 Sum: 1.9727999999999999\n", + "[2023-03-27T18:08:36.164268+0200][38480][INFO] Step 400: MLoss: 1.1645 GLoss: 0.6393 Sum: 1.8038\n", + "[2023-03-27T18:08:41.835318+0200][38480][INFO] Step 500: MLoss: 1.1717 GLoss: 0.6158 Sum: 1.7875\n", + "[2023-03-27T18:08:47.581383+0200][38480][INFO] Step 600: MLoss: 1.1946 GLoss: 0.5384 Sum: 1.733\n", + "[2023-03-27T18:08:53.378127+0200][38480][INFO] Step 700: MLoss: 1.1343 GLoss: 0.5135 Sum: 1.6478000000000002\n", + "[2023-03-27T18:08:59.698145+0200][38480][INFO] Step 800: MLoss: 1.1168 GLoss: 0.4788 Sum: 1.5956000000000001\n", + "[2023-03-27T18:09:05.752638+0200][38480][INFO] Step 900: MLoss: 1.1034 GLoss: 0.4734 Sum: 1.5768\n", + "[2023-03-27T18:09:12.070003+0200][38480][INFO] Step 1000: MLoss: 1.142 GLoss: 0.4692 Sum: 1.6112\n", + "[2023-03-27T18:09:18.112377+0200][38480][INFO] Step 1100: MLoss: 1.1691 GLoss: 0.4602 Sum: 1.6293\n", + "[2023-03-27T18:09:25.549484+0200][38480][INFO] Step 1200: MLoss: 1.1201 GLoss: 0.4578 Sum: 1.5779\n", + "[2023-03-27T18:09:31.574874+0200][38480][INFO] Step 1300: MLoss: 1.1436 GLoss: 0.4429 Sum: 1.5865\n", + "[2023-03-27T18:09:37.672797+0200][38480][INFO] Step 1400: MLoss: 1.1093 GLoss: 0.449 Sum: 1.5583\n", + "[2023-03-27T18:09:44.149652+0200][38480][INFO] Step 1500: MLoss: 1.1468 GLoss: 0.4347 Sum: 1.5815000000000001\n", + "[2023-03-27T18:09:49.923915+0200][38480][INFO] Step 1600: MLoss: 1.1545 GLoss: 0.4313 Sum: 1.5858\n", + "[2023-03-27T18:09:55.733558+0200][38480][INFO] Step 1700: MLoss: 1.102 GLoss: 0.4305 Sum: 1.5325000000000002\n", + "[2023-03-27T18:10:03.367053+0200][38480][INFO] Step 1800: MLoss: 1.0953 GLoss: 0.4267 Sum: 1.522\n", + 
"[2023-03-27T18:10:10.533359+0200][38480][INFO] Step 1900: MLoss: 1.1247 GLoss: 0.4223 Sum: 1.5470000000000002\n", + "[2023-03-27T18:10:17.355705+0200][38480][INFO] Step 2000: MLoss: 1.2767 GLoss: 0.4266 Sum: 1.7033\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# define the model hyper-parameters\n", + "plugin_params.update(\n", + " is_classification = False,\n", + " n_iter = 500, # epochs\n", + " lr = 5e-4,\n", + " weight_decay = 1e-4,\n", + " batch_size = 1250,\n", + " n_layers_hidden = 3,\n", + " dim_hidden = 256,\n", + " num_timesteps = 100, # timesteps in diffusion\n", + ")\n", + "plugin = Plugins().get(\"ddpm\", **plugin_params)\n", + "plugin.fit(loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "83064f94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plugin.loss_history.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "af9d6df1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
053.753993-2.4752390.404968406.8983861.788962450.0737241221.5510775.69717712.45137714.83544581.5156960.0
1241.769932-33.9059337.4401881722.53605348.0345561312.2644572141.11936831.25907483.65474928.591759489.1726743.0
225.3449040.769463-11.237007-335.794326-3.595284-234.179124382.9075157.63768417.7483003.38029673.7010481.0
315.635557-28.371864-19.808469800.08844661.404066-596.053591-1749.79750528.376345-71.868790-14.556346-38.3151791.0
4-0.796959-8.546869-4.726590128.3430281.083628-288.3521041184.6802738.08150023.0128282.16859736.6728400.0
5-31.203381-39.052177-57.6510321269.158981-22.793850101.490751-661.9978235.01273819.61582226.791456-63.7736783.0
6-120.526480-49.314650-67.642982650.13681665.155843598.106999-3468.7530373.75056652.556860-108.310847-91.8163103.0
713.172627-7.196406-20.153565746.262383-30.8466881592.8153971610.699379-15.57666027.31969245.376814135.8714220.0
\n", + "
" + ], + "text/plain": [ + " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", + "0 53.753993 -2.475239 0.404968 406.898386 1.788962 \n", + "1 241.769932 -33.905933 7.440188 1722.536053 48.034556 \n", + "2 25.344904 0.769463 -11.237007 -335.794326 -3.595284 \n", + "3 15.635557 -28.371864 -19.808469 800.088446 61.404066 \n", + "4 -0.796959 -8.546869 -4.726590 128.343028 1.083628 \n", + "5 -31.203381 -39.052177 -57.651032 1269.158981 -22.793850 \n", + "6 -120.526480 -49.314650 -67.642982 650.136816 65.155843 \n", + "7 13.172627 -7.196406 -20.153565 746.262383 -30.846688 \n", + "\n", + " free sulfur dioxide total sulfur dioxide density pH \\\n", + "0 450.073724 1221.551077 5.697177 12.451377 \n", + "1 1312.264457 2141.119368 31.259074 83.654749 \n", + "2 -234.179124 382.907515 7.637684 17.748300 \n", + "3 -596.053591 -1749.797505 28.376345 -71.868790 \n", + "4 -288.352104 1184.680273 8.081500 23.012828 \n", + "5 101.490751 -661.997823 5.012738 19.615822 \n", + "6 598.106999 -3468.753037 3.750566 52.556860 \n", + "7 1592.815397 1610.699379 -15.576660 27.319692 \n", + "\n", + " sulphates alcohol quality \n", + "0 14.835445 81.515696 0.0 \n", + "1 28.591759 489.172674 3.0 \n", + "2 3.380296 73.701048 1.0 \n", + "3 -14.556346 -38.315179 1.0 \n", + "4 2.168597 36.672840 0.0 \n", + "5 26.791456 -63.773678 3.0 \n", + "6 -108.310847 -91.816310 3.0 \n", + "7 45.376814 135.871422 0.0 " + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.model.generate(8)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "be62c2f0", + "metadata": {}, + "source": [ + "### Conditional data generation\n", + "\n", + "A conditional variable `cond` can be provided to the `fit` method. It can be either a column name in the dataset or a custom array. The model will then learn the conditional distribution of the dataset given `cond`. 
In this case, an array must be provided as the `cond` argument of the `generate` method." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "56a1fc7e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-03-27T18:03:45.005934+0200][38480][INFO] Step 100: MLoss: 0.9066 GLoss: 1.0013 Sum: 1.9079000000000002\n", + "[2023-03-27T18:03:51.387087+0200][38480][INFO] Step 200: MLoss: 0.4735 GLoss: 1.0112 Sum: 1.4847000000000001\n", + "[2023-03-27T18:03:59.107456+0200][38480][INFO] Step 300: MLoss: 0.4567 GLoss: 1.001 Sum: 1.4577\n", + "[2023-03-27T18:04:05.835508+0200][38480][INFO] Step 400: MLoss: 0.2715 GLoss: 0.9856 Sum: 1.2571\n", + "[2023-03-27T18:04:12.739590+0200][38480][INFO] Step 500: MLoss: 0.2193 GLoss: 0.9046 Sum: 1.1239\n", + "[2023-03-27T18:04:19.417762+0200][38480][INFO] Step 600: MLoss: 0.0143 GLoss: 0.8463 Sum: 0.8606\n", + "[2023-03-27T18:04:26.022729+0200][38480][INFO] Step 700: MLoss: 0.0048 GLoss: 0.7509 Sum: 0.7557\n", + "[2023-03-27T18:04:32.757598+0200][38480][INFO] Step 800: MLoss: 0.0083 GLoss: 0.7102 Sum: 0.7185\n", + "[2023-03-27T18:04:39.550873+0200][38480][INFO] Step 900: MLoss: 0.0029 GLoss: 0.675 Sum: 0.6779000000000001\n", + "[2023-03-27T18:04:46.573464+0200][38480][INFO] Step 1000: MLoss: 0.0039 GLoss: 0.6414 Sum: 0.6453\n", + "[2023-03-27T18:04:53.438631+0200][38480][INFO] Step 1100: MLoss: 0.003 GLoss: 0.6046 Sum: 0.6076\n", + "[2023-03-27T18:05:01.283222+0200][38480][INFO] Step 1200: MLoss: 0.0013 GLoss: 0.6297 Sum: 0.631\n", + "[2023-03-27T18:05:08.559280+0200][38480][INFO] Step 1300: MLoss: 0.0012 GLoss: 0.5479 Sum: 0.5491\n", + "[2023-03-27T18:05:15.536738+0200][38480][INFO] Step 1400: MLoss: 0.0067 GLoss: 0.5275 Sum: 0.5342\n", + "[2023-03-27T18:05:22.391711+0200][38480][INFO] Step 1500: MLoss: 0.0007 GLoss: 0.5252 Sum: 0.5259\n", + "[2023-03-27T18:05:29.285959+0200][38480][INFO] Step 1600: MLoss: 0.0018 GLoss: 0.5017 Sum: 0.5035000000000001\n", + 
"[2023-03-27T18:05:36.288634+0200][38480][INFO] Step 1700: MLoss: 0.0012 GLoss: 0.5013 Sum: 0.5025\n", + "[2023-03-27T18:05:43.485831+0200][38480][INFO] Step 1800: MLoss: 0.0009 GLoss: 0.4927 Sum: 0.49360000000000004\n", + "[2023-03-27T18:05:50.629387+0200][38480][INFO] Step 1900: MLoss: 0.0009 GLoss: 0.4931 Sum: 0.494\n", + "[2023-03-27T18:05:58.709478+0200][38480][INFO] Step 2000: MLoss: 0.0006 GLoss: 0.4864 Sum: 0.487\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.fit(loader, cond='quality')" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "3fcb9493", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plugin.loss_history.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "2ea981cd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([3, 4, 5, 6, 7, 8, 9])" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outcome = np.array([3, 4, 5, 6, 7, 8, 9])\n", + "outcome" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "bbd33233", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-03-27T18:05:59.734678+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:05:59.737612+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:00.157952+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:00.160095+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:00.484737+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:00.485757+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:00.786487+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:00.788466+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. 
Original dtype float64.\n", + "[2023-03-27T18:06:01.100020+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:01.102261+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:01.460078+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:01.462163+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:01.805568+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:01.807568+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:02.183897+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:02.185904+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:02.569835+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:02.571874+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:03.033272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:03.035146+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:03.579187+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:03.582312+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:04.128201+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:04.131216+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:04.909594+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:04.912681+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:05.506491+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:05.509890+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:06.285555+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:06.287092+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:06.748144+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:06.751143+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:07.239364+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:07.241364+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:07.833861+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:07.835862+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:08.270020+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:08.273103+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:08.579762+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:08.581664+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:08.995746+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:08.996750+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:09.387130+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:09.389133+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:09.913255+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:09.915271+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:10.414403+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:10.417511+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:10.986099+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:10.988092+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:11.384006+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:11.699391+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:11.700392+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:12.138923+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:12.140923+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:12.604077+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:12.606674+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:12.997333+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:12.999663+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:13.547570+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:13.550541+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:13.954516+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:13.956516+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:14.445116+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:14.452112+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:14.829066+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:14.832071+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:15.312829+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:15.315831+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:15.757355+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:15.759926+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:16.143136+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:16.145134+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:16.560027+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:16.562025+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:16.861918+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:16.863918+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:17.183558+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:17.637582+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:17.640281+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:17.997687+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:17.998689+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:18.345381+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:18.347383+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:18.676026+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:18.678607+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:19.007549+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:19.010506+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:19.346424+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:19.348531+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:19.696186+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:19.697186+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:20.073809+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:20.077249+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:06:20.472414+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:20.475399+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:20.942265+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:20.944268+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:21.302342+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:21.304314+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:21.665401+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:21.666987+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:22.067854+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:22.069371+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:22.392718+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:22.395677+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:22.716515+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:22.717515+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:06:23.047434+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:23.049434+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:23.399152+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:23.401151+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:23.745625+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:23.747624+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:24.098540+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:24.099540+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:24.421854+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:24.422839+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:24.738758+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:24.739667+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:25.058648+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:25.060550+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:25.399681+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:25.401599+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:25.737806+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:25.738793+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:26.069784+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:26.071290+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:26.416549+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:26.418554+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:26.801542+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:26.803529+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:27.139240+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:27.141225+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:27.488070+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:27.490052+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:27.823788+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:27.824814+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:28.163857+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:28.166838+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:28.499341+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:28.501342+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:28.823408+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:28.825499+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:29.125222+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:29.128129+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:29.492914+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:29.496428+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:29.833079+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:29.835167+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:30.217776+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:30.219777+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:30.536676+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:30.538667+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:30.861816+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:30.863812+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:31.177127+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:31.180126+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:06:31.606751+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:31.607978+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:31.949768+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:31.951785+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:32.289786+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:32.291671+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:32.629730+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:32.942556+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:32.945560+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:32.949202+0200][38480][INFO] [residual sugar] quality loss for constraints ge = 0.6. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:33.286281+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:33.287280+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:06:33.620445+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:33.622445+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:33.945427+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:33.947494+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:34.298877+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:34.300955+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:34.618880+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:34.620789+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:34.959467+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:34.961383+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:35.296247+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:35.298303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:35.763113+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:35.765112+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:36.178981+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:36.181338+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:36.555008+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:36.555991+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:36.880093+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:36.881104+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:37.299044+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:37.301205+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:37.708557+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:37.711544+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:38.087165+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:38.089166+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:38.482563+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:38.483562+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:38.941184+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:38.942166+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:39.291995+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:39.294883+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:39.642425+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:39.645485+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:39.965926+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:39.967445+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:40.280863+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:40.281866+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:40.567363+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:40.569362+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:40.863820+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:40.865893+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:41.406311+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:41.409435+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:42.003319+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:42.006307+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:42.470804+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:42.471786+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:42.768361+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:42.770360+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", + "[2023-03-27T18:06:43.102405+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:43.105718+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:06:43.426329+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:43.429478+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:43.757004+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:43.759124+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:44.083407+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:44.084408+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:44.400443+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:44.401428+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:44.706402+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:44.708999+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:45.018534+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:45.019535+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:45.519397+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:45.521407+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:45.921477+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:45.922985+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:46.265432+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:46.267956+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:46.717722+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:46.719733+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:47.062693+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:47.064691+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:47.417125+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:47.418108+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:47.758309+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:47.760595+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:48.135817+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:48.137801+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:48.458595+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:48.460608+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:48.754069+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:48.756024+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:49.049862+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:49.051462+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:49.350537+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:49.352536+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:49.766318+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:49.769390+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:50.276306+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:50.279351+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:50.665664+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:50.666685+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:51.009462+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.012707+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:51.308313+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.309313+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:51.637138+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.639120+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", + "[2023-03-27T18:06:51.979944+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.980946+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:52.297063+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:52.298062+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:52.625280+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:52.628303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:52.938341+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:52.939345+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:53.233624+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:53.235624+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:53.550284+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:53.552284+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:53.859100+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:53.863229+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:54.227895+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:54.229895+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:54.534473+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:54.536457+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:54.835486+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:54.837487+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:55.132594+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:55.134593+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:55.465635+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:55.467185+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:55.807745+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:55.810517+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:56.300336+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:56.302923+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:56.604424+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:56.605423+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:56.898530+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:56.900544+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:57.205520+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:57.206520+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:57.503438+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:57.505437+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:57.819558+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:57.821581+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:06:58.160813+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:58.163417+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:58.462315+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:58.463303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:58.815614+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:58.817596+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:59.129940+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:59.130934+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:59.577632+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:59.580621+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:59.909210+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:59.910211+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:07:00.263906+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:00.265906+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:00.573175+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:00.574177+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:00.866210+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:00.868793+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:01.205344+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:01.207327+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:01.606906+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:01.608906+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:02.102300+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:02.105211+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:02.503969+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:02.506485+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:02.906864+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:02.908864+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:03.298141+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:03.300142+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:03.619687+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:03.621670+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:03.942307+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:03.946964+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:04.383317+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:04.384318+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:04.685032+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:04.687584+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:04.985829+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:04.986829+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:05.266858+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:05.269157+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:05.580166+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:05.582149+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:05.889785+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:05.892186+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:06.211209+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:06.213722+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:06.513714+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:06.515729+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:06.832167+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:06.834177+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:07.144798+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:07.146797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:07.479304+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:07.481835+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:07.846999+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:07.848997+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:08.195789+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:08.197813+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:08.200813+0200][38480][INFO] [residual sugar] quality loss for constraints le = 65.8. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:08.691113+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:08.694249+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:09.231893+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:09.235438+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:09.713446+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:09.716162+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:10.805837+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:10.809012+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:11.446846+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:11.450600+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:12.110297+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:12.114136+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:12.587219+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:12.589217+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:13.186604+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:13.188628+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:13.765722+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:13.767730+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:14.222493+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:14.225273+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:14.581621+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:14.582622+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:14.916005+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:14.917005+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:15.232768+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:15.233771+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. 
Original dtype float64.\n", + "[2023-03-27T18:07:15.587426+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:15.589426+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:15.937914+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:15.939914+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:16.341209+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:16.343228+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:16.667291+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:16.669292+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:16.989838+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:16.991912+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:17.306825+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:17.308797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:17.659105+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:17.661131+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:18.018946+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:18.019947+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:18.393086+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:18.396311+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:18.830421+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:18.833527+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:19.232926+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:19.236012+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:19.669845+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:19.672139+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:20.034654+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:20.035654+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:20.365288+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:20.367291+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:20.677852+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:20.680692+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:20.988636+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:20.990732+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:21.326922+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:21.329905+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:21.682149+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:21.684150+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:22.042272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:22.043272+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:22.417916+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:22.418916+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:22.749237+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:22.751237+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:23.090475+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:23.091459+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:23.470508+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:23.473305+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:23.821072+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:23.823567+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:23.827191+0200][38480][INFO] [residual sugar] quality loss for constraints ge = 0.6. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:24.193607+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:24.194590+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:24.532529+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:24.534525+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:24.876586+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:24.878585+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:25.216076+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:25.217076+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:25.599528+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:25.601333+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:26.159795+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:26.161982+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:26.541276+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:26.542274+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:26.869887+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:26.872038+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:27.183814+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:27.186139+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:27.522592+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:27.524574+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:27.885528+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:27.886547+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:28.236311+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:28.237310+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:28.569622+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:28.571622+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", + "[2023-03-27T18:07:28.889372+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:28.890372+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:29.200272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:29.202272+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:29.533137+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:29.535216+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:29.936280+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:29.939026+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:30.369796+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:30.371797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:30.718054+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:30.720128+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:31.139806+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:31.140809+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mplugin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutcome\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, count, constraints, random_state, **kwargs)\u001b[0m\n\u001b[0;32m 337\u001b[0m \u001b[0msyn_schema\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mSchema\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_constraints\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgen_constraints\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 338\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 339\u001b[1;33m \u001b[0mX_syn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_generate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msyn_schema\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msyn_schema\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 340\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 341\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mX_syn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_tabular\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36m_generate\u001b[1;34m(self, count, syn_schema, **kwargs)\u001b[0m\n\u001b[0;32m 246\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 247\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 248\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_safe_generate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msyn_schema\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 249\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 250\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + 
"\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36m_safe_generate\u001b[1;34m(self, gen_cbk, count, syn_schema, **kwargs)\u001b[0m\n\u001b[0;32m 391\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msampling_patience\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 392\u001b[0m \u001b[1;31m# sample\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 393\u001b[1;33m \u001b[0miter_samples\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgen_cbk\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 394\u001b[0m iter_samples_df = pd.DataFrame(\n\u001b[0;32m 395\u001b[0m \u001b[0miter_samples\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining_schema\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36mcallback\u001b[1;34m(count)\u001b[0m\n\u001b[0;32m 241\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 242\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# type: ignore\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 243\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 244\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_classification\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_iloc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\__init__.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, count, cond)\u001b[0m\n\u001b[0;32m 211\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 212\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 213\u001b[1;33m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiffusion\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 214\u001b[0m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_col_perm\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\gaussian_multinomial_diffsuion.py\u001b[0m in \u001b[0;36msample_all\u001b[1;34m(self, num_samples, cond, max_batch_size, ddim)\u001b[0m\n\u001b[0;32m 951\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 952\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mb\u001b[0m \u001b[1;32min\u001b[0m 
\u001b[0mbs\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 953\u001b[1;33m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msample_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 954\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0many\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msample\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 955\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"found NaNs in sample\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\autograd\\grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 28\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\gaussian_multinomial_diffsuion.py\u001b[0m in \u001b[0;36msample\u001b[1;34m(self, num_samples, cond)\u001b[0m\n\u001b[0;32m 918\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"Sample timestep {i:4d}\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"\\r\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 919\u001b[0m \u001b[0mt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfull\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 920\u001b[1;33m model_out = self.denoise_fn(\n\u001b[0m\u001b[0;32m 921\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mz_norm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlog_z\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 922\u001b[0m )\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\modules.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, t, y)\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0memb\u001b[0m \u001b[1;33m+=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0memb_nonlin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlabel_emb\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0memb\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 113\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\mlp.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mvalidate_arguments\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marbitrary_types_allowed\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 399\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 400\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 401\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 402\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_train_epoch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloader\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mfloat\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in 
\u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 204\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\mlp.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mvalidate_arguments\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marbitrary_types_allowed\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 204\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "plugin.generate(len(outcome), cond=outcome)" + ] + }, + { + "cell_type": "markdown", + "id": "ea5abc50", + "metadata": {}, + "source": [ + "## Congratulations!\n", + "\n", + "Congratulations on completing this notebook tutorial! 
If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!\n", + "\n", + "### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub\n", + "\n", + "- The easiest way to help our community is just by starring the Repos! This helps raise awareness of the tools we're building.\n", + "\n", + "\n", + "### Checkout other projects from vanderschaarlab\n", + "- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)\n", + "- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From bcdce4b793cc5ad38b1687b28d5821a4bdea579e Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 27 Mar 2023 20:53:58 +0200 Subject: [PATCH 28/95] add TabDDPM tutorial --- ...al8_tabular_modelling_with_diffusion.ipynb | 1936 +++++++++++++++++ 1 file changed, 1936 insertions(+) create mode 100644 tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb new file mode 100644 index 00000000..97e38401 --- /dev/null +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -0,0 +1,1936 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "97e2d93c", + "metadata": {}, + "source": [ + "# Tutorial 8: Modelling tabular data with diffusion models\n", + "\n", + "This tutorial demonstrates how to use a denoising diffusion probabilistic model (DDPM) to synthesize tabular data. 
The algorithm was proposed in [TabDDPM: Modelling Tabular Data with Diffusion Models](https://arxiv.org/abs/2209.15421)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "696e0157", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[KeOps] Warning : \n", + " The default C++ compiler could not be found on your system.\n", + " You need to either define the CXX environment variable or a symlink to the g++ command.\n", + " For example if g++-8 is the command you can do\n", + " import os\n", + " os.environ['CXX'] = 'g++-8'\n", + " \n", + "[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode\n" + ] + } + ], + "source": [ + "# stdlib\n", + "import sys\n", + "import warnings\n", + "sys.path.insert(0, '../src')\n", + "\n", + "# third party\n", + "import numpy as np\n", + "from sklearn.datasets import load_iris, load_diabetes\n", + "\n", + "# synthcity absolute\n", + "import synthcity.logger as log\n", + "from synthcity.plugins import Plugins\n", + "from synthcity.plugins.core.dataloader import GenericDataLoader\n", + "\n", + "log.add(sink=sys.stderr, level=\"INFO\")\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "54ce9a10", + "metadata": {}, + "source": [ + "## Synthesize a classification dataset\n", + "\n", + "For classification datasets, TabDDPM automatically uses the labels as the conditional variable during training. You should not provide an additional `cond` argument to the `fit` method." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "51076cdc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
05.13.51.40.20
14.93.01.40.20
24.73.21.30.20
34.63.11.50.20
45.03.61.40.20
\n", + "
" + ], + "text/plain": [ + " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", + "0 5.1 3.5 1.4 0.2 \n", + "1 4.9 3.0 1.4 0.2 \n", + "2 4.7 3.2 1.3 0.2 \n", + "3 4.6 3.1 1.5 0.2 \n", + "4 5.0 3.6 1.4 0.2 \n", + "\n", + " target \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 0 " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: preprocessing data with OneHotEncoder or StandardScaler is not needed or recommended. Synthcity handles feature encoding and standardization internally.\n", + "\n", + "X, y = load_iris(return_X_y=True, as_frame=True)\n", + "X[\"target\"] = y\n", + "\n", + "loader = GenericDataLoader(X, target_column=\"target\", sensitive_columns=[])\n", + "\n", + "loader.dataframe().head()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "52397e4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 50\n", + "1 50\n", + "2 50\n", + "Name: target, dtype: int64" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.value_counts()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cda52bea", + "metadata": {}, + "source": [ + "### Model fitting" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "3bf24be4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-03-27T15:19:24.516935+0200][30696][INFO] Step 100: MLoss: 0.0 GLoss: 0.2235 Sum: 0.2235\n", + "[2023-03-27T15:19:25.913968+0200][30696][INFO] Step 200: MLoss: 0.0 GLoss: 0.2298 Sum: 0.2298\n", + "[2023-03-27T15:19:27.191123+0200][30696][INFO] Step 300: MLoss: 0.0 GLoss: 0.2305 Sum: 0.2305\n", + "[2023-03-27T15:19:28.432055+0200][30696][INFO] Step 400: MLoss: 0.0 GLoss: 0.2273 Sum: 0.2273\n", + "[2023-03-27T15:19:29.766838+0200][30696][INFO] Step 500: MLoss: 0.0 GLoss: 0.2333 Sum: 0.2333\n", + 
"[2023-03-27T15:19:31.280538+0200][30696][INFO] Step 600: MLoss: 0.0 GLoss: 0.221 Sum: 0.221\n", + "[2023-03-27T15:19:33.034999+0200][30696][INFO] Step 700: MLoss: 0.0 GLoss: 0.2123 Sum: 0.2123\n", + "[2023-03-27T15:19:34.519078+0200][30696][INFO] Step 800: MLoss: 0.0 GLoss: 0.2212 Sum: 0.2212\n", + "[2023-03-27T15:19:36.020932+0200][30696][INFO] Step 900: MLoss: 0.0 GLoss: 0.2014 Sum: 0.2014\n", + "[2023-03-27T15:19:38.330664+0200][30696][INFO] Step 1000: MLoss: 0.0 GLoss: 0.2069 Sum: 0.2069\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# define the model hyper-parameters\n", + "plugin_params = dict(\n", + " is_classification = True,\n", + " n_iter = 1000, # epochs\n", + " lr = 0.002,\n", + " weight_decay = 1e-4,\n", + " batch_size = 1000,\n", + " model_type = \"mlp\", # or \"resnet\"\n", + " num_timesteps = 500, # timesteps in diffusion\n", + " n_layers_hidden = 3,\n", + " dim_hidden = 256,\n", + " dim_embed = 128,\n", + " dropout = 0.0,\n", + " # performance logging\n", + " log_interval = 10,\n", + " print_interval = 100,\n", + ")\n", + "\n", + "plugin = Plugins().get(\"ddpm\", **plugin_params)\n", + "plugin.fit(loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e1a270c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TabDDPM(\n", + " (diffusion): GaussianMultinomialDiffusion(\n", + " (denoise_fn): MLPDiffusion(\n", + " (emb_nonlin): SiLU()\n", + " (proj): Linear(in_features=4, out_features=128, bias=True)\n", + " (time_emb): TimeStepEmbedding(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): SiLU()\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " )\n", + " (label_emb): Embedding(3, 128)\n", + " (model): MLP(\n", + " (model): Sequential(\n", + " (0): LinearLayer(\n", + " (model): Sequential(\n", + " (0): 
Linear(in_features=128, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (1): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (2): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (3): Linear(in_features=256, out_features=4, bias=True)\n", + " )\n", + " (loss): MSELoss()\n", + " )\n", + " )\n", + " )\n", + " (ema_model): MLPDiffusion(\n", + " (emb_nonlin): SiLU()\n", + " (proj): Linear(in_features=4, out_features=128, bias=True)\n", + " (time_emb): TimeStepEmbedding(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): SiLU()\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " )\n", + " (label_emb): Embedding(3, 128)\n", + " (model): MLP(\n", + " (model): Sequential(\n", + " (0): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=128, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (1): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (2): LinearLayer(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU()\n", + " )\n", + " )\n", + " (3): Linear(in_features=256, out_features=4, bias=True)\n", + " )\n", + " (loss): MSELoss()\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "49b18ada", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { 
+ "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# plot training curves\n", + "plugin.loss_history.plot()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cf5241cc", + "metadata": {}, + "source": [ + "### Data generation\n", + "\n", + "Since the model training is conditional on the labels, the data generation requires the labels as well. You can pass the labels as a `cond` argument to the `generate` method. If `cond` is not provided, the model will randomly generate the labels following the multinomial distribution of the training labels." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a2e81779", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
06.4421192.9347334.3269331.3725701
16.2854122.7214405.1209012.0575472
24.6963502.0427262.8569090.7889351
35.3360192.6885334.1632831.1920511
46.0818253.2216824.6457681.5052931
55.6901652.3360884.1056301.2966071
65.3989352.7577133.8099841.1613691
77.3582703.2834286.4965902.3172382
86.5953272.5985265.8056531.4513532
95.2247182.7962243.5009151.1252481
\n", + "
" + ], + "text/plain": [ + " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", + "0 6.442119 2.934733 4.326933 1.372570 \n", + "1 6.285412 2.721440 5.120901 2.057547 \n", + "2 4.696350 2.042726 2.856909 0.788935 \n", + "3 5.336019 2.688533 4.163283 1.192051 \n", + "4 6.081825 3.221682 4.645768 1.505293 \n", + "5 5.690165 2.336088 4.105630 1.296607 \n", + "6 5.398935 2.757713 3.809984 1.161369 \n", + "7 7.358270 3.283428 6.496590 2.317238 \n", + "8 6.595327 2.598526 5.805653 1.451353 \n", + "9 5.224718 2.796224 3.500915 1.125248 \n", + "\n", + " target \n", + "0 1 \n", + "1 2 \n", + "2 1 \n", + "3 1 \n", + "4 1 \n", + "5 1 \n", + "6 1 \n", + "7 2 \n", + "8 2 \n", + "9 1 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.generate(10)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f2d6c6cb", + "metadata": {}, + "source": [ + "### Conditional data generation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1f55ffdb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
05.2009353.4104481.2944040.2501560
14.8921723.4047651.3739660.3176620
24.5464153.0013621.3792670.1460120
36.9123333.3724784.7320091.6384991
45.4792602.6232463.4961611.2651181
55.6916102.5684203.6208421.0259881
66.9353143.2469516.2097022.2368082
77.0824953.0612085.9071951.9507212
86.0660102.5531235.1930901.6390342
\n", + "
" + ], + "text/plain": [ + " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", + "0 5.200935 3.410448 1.294404 0.250156 \n", + "1 4.892172 3.404765 1.373966 0.317662 \n", + "2 4.546415 3.001362 1.379267 0.146012 \n", + "3 6.912333 3.372478 4.732009 1.638499 \n", + "4 5.479260 2.623246 3.496161 1.265118 \n", + "5 5.691610 2.568420 3.620842 1.025988 \n", + "6 6.935314 3.246951 6.209702 2.236808 \n", + "7 7.082495 3.061208 5.907195 1.950721 \n", + "8 6.066010 2.553123 5.193090 1.639034 \n", + "\n", + " target \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 1 \n", + "4 1 \n", + "5 1 \n", + "6 2 \n", + "7 2 \n", + "8 2 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])\n", + "plugin.generate(len(labels), cond=labels)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "adf672a5", + "metadata": {}, + "source": [ + "## Synthesize a regression dataset\n", + "\n", + "For regression datasets, there is no conditional variable by default. The model learns the joint distribution of the whole dataset and generates new data points from it." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "13df0848", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
count4898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.0000004898.000000
mean6.8547880.2782410.3341926.3914150.04577235.308085138.3606570.9940273.1882670.48984710.5142675.877909
std0.8438680.1007950.1210205.0720580.02184817.00713742.4980650.0029910.1510010.1141261.2306210.885639
min3.8000000.0800000.0000000.6000000.0090002.0000009.0000000.9871102.7200000.2200008.0000003.000000
25%6.3000000.2100000.2700001.7000000.03600023.000000108.0000000.9917233.0900000.4100009.5000005.000000
50%6.8000000.2600000.3200005.2000000.04300034.000000134.0000000.9937403.1800000.47000010.4000006.000000
75%7.3000000.3200000.3900009.9000000.05000046.000000167.0000000.9961003.2800000.55000011.4000006.000000
max14.2000001.1000001.66000065.8000000.346000289.000000440.0000001.0389803.8200001.08000014.2000009.000000
\n", + "
" + ], + "text/plain": [ + " fixed acidity volatile acidity citric acid residual sugar \\\n", + "count 4898.000000 4898.000000 4898.000000 4898.000000 \n", + "mean 6.854788 0.278241 0.334192 6.391415 \n", + "std 0.843868 0.100795 0.121020 5.072058 \n", + "min 3.800000 0.080000 0.000000 0.600000 \n", + "25% 6.300000 0.210000 0.270000 1.700000 \n", + "50% 6.800000 0.260000 0.320000 5.200000 \n", + "75% 7.300000 0.320000 0.390000 9.900000 \n", + "max 14.200000 1.100000 1.660000 65.800000 \n", + "\n", + " chlorides free sulfur dioxide total sulfur dioxide density \\\n", + "count 4898.000000 4898.000000 4898.000000 4898.000000 \n", + "mean 0.045772 35.308085 138.360657 0.994027 \n", + "std 0.021848 17.007137 42.498065 0.002991 \n", + "min 0.009000 2.000000 9.000000 0.987110 \n", + "25% 0.036000 23.000000 108.000000 0.991723 \n", + "50% 0.043000 34.000000 134.000000 0.993740 \n", + "75% 0.050000 46.000000 167.000000 0.996100 \n", + "max 0.346000 289.000000 440.000000 1.038980 \n", + "\n", + " pH sulphates alcohol quality \n", + "count 4898.000000 4898.000000 4898.000000 4898.000000 \n", + "mean 3.188267 0.489847 10.514267 5.877909 \n", + "std 0.151001 0.114126 1.230621 0.885639 \n", + "min 2.720000 0.220000 8.000000 3.000000 \n", + "25% 3.090000 0.410000 9.500000 5.000000 \n", + "50% 3.180000 0.470000 10.400000 6.000000 \n", + "75% 3.280000 0.550000 11.400000 6.000000 \n", + "max 3.820000 1.080000 14.200000 9.000000 " + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv(\"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv\", sep=\";\")\n", + "\n", + "loader = GenericDataLoader(df, target_column=\"quality\", sensitive_columns=[])\n", + "loader.dataframe().describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "14bca1cd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "[2023-03-27T18:08:18.761007+0200][38480][INFO] Step 100: MLoss: 1.2836 GLoss: 0.9867 Sum: 2.2703\n", + "[2023-03-27T18:08:24.679745+0200][38480][INFO] Step 200: MLoss: 1.2622 GLoss: 0.9409 Sum: 2.2031\n", + "[2023-03-27T18:08:30.391531+0200][38480][INFO] Step 300: MLoss: 1.2059 GLoss: 0.7669 Sum: 1.9727999999999999\n", + "[2023-03-27T18:08:36.164268+0200][38480][INFO] Step 400: MLoss: 1.1645 GLoss: 0.6393 Sum: 1.8038\n", + "[2023-03-27T18:08:41.835318+0200][38480][INFO] Step 500: MLoss: 1.1717 GLoss: 0.6158 Sum: 1.7875\n", + "[2023-03-27T18:08:47.581383+0200][38480][INFO] Step 600: MLoss: 1.1946 GLoss: 0.5384 Sum: 1.733\n", + "[2023-03-27T18:08:53.378127+0200][38480][INFO] Step 700: MLoss: 1.1343 GLoss: 0.5135 Sum: 1.6478000000000002\n", + "[2023-03-27T18:08:59.698145+0200][38480][INFO] Step 800: MLoss: 1.1168 GLoss: 0.4788 Sum: 1.5956000000000001\n", + "[2023-03-27T18:09:05.752638+0200][38480][INFO] Step 900: MLoss: 1.1034 GLoss: 0.4734 Sum: 1.5768\n", + "[2023-03-27T18:09:12.070003+0200][38480][INFO] Step 1000: MLoss: 1.142 GLoss: 0.4692 Sum: 1.6112\n", + "[2023-03-27T18:09:18.112377+0200][38480][INFO] Step 1100: MLoss: 1.1691 GLoss: 0.4602 Sum: 1.6293\n", + "[2023-03-27T18:09:25.549484+0200][38480][INFO] Step 1200: MLoss: 1.1201 GLoss: 0.4578 Sum: 1.5779\n", + "[2023-03-27T18:09:31.574874+0200][38480][INFO] Step 1300: MLoss: 1.1436 GLoss: 0.4429 Sum: 1.5865\n", + "[2023-03-27T18:09:37.672797+0200][38480][INFO] Step 1400: MLoss: 1.1093 GLoss: 0.449 Sum: 1.5583\n", + "[2023-03-27T18:09:44.149652+0200][38480][INFO] Step 1500: MLoss: 1.1468 GLoss: 0.4347 Sum: 1.5815000000000001\n", + "[2023-03-27T18:09:49.923915+0200][38480][INFO] Step 1600: MLoss: 1.1545 GLoss: 0.4313 Sum: 1.5858\n", + "[2023-03-27T18:09:55.733558+0200][38480][INFO] Step 1700: MLoss: 1.102 GLoss: 0.4305 Sum: 1.5325000000000002\n", + "[2023-03-27T18:10:03.367053+0200][38480][INFO] Step 1800: MLoss: 1.0953 GLoss: 0.4267 Sum: 1.522\n", + 
"[2023-03-27T18:10:10.533359+0200][38480][INFO] Step 1900: MLoss: 1.1247 GLoss: 0.4223 Sum: 1.5470000000000002\n", + "[2023-03-27T18:10:17.355705+0200][38480][INFO] Step 2000: MLoss: 1.2767 GLoss: 0.4266 Sum: 1.7033\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# define the model hyper-parameters\n", + "plugin_params.update(\n", + " is_classification = False,\n", + " n_iter = 500, # epochs\n", + " lr = 5e-4,\n", + " weight_decay = 1e-4,\n", + " batch_size = 1250,\n", + " n_layers_hidden = 3,\n", + " dim_hidden = 256,\n", + " num_timesteps = 100, # timesteps in diffusion\n", + ")\n", + "plugin = Plugins().get(\"ddpm\", **plugin_params)\n", + "plugin.fit(loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "83064f94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plugin.loss_history.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "af9d6df1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
053.753993-2.4752390.404968406.8983861.788962450.0737241221.5510775.69717712.45137714.83544581.5156960.0
1241.769932-33.9059337.4401881722.53605348.0345561312.2644572141.11936831.25907483.65474928.591759489.1726743.0
225.3449040.769463-11.237007-335.794326-3.595284-234.179124382.9075157.63768417.7483003.38029673.7010481.0
315.635557-28.371864-19.808469800.08844661.404066-596.053591-1749.79750528.376345-71.868790-14.556346-38.3151791.0
4-0.796959-8.546869-4.726590128.3430281.083628-288.3521041184.6802738.08150023.0128282.16859736.6728400.0
5-31.203381-39.052177-57.6510321269.158981-22.793850101.490751-661.9978235.01273819.61582226.791456-63.7736783.0
6-120.526480-49.314650-67.642982650.13681665.155843598.106999-3468.7530373.75056652.556860-108.310847-91.8163103.0
713.172627-7.196406-20.153565746.262383-30.8466881592.8153971610.699379-15.57666027.31969245.376814135.8714220.0
\n", + "
" + ], + "text/plain": [ + " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", + "0 53.753993 -2.475239 0.404968 406.898386 1.788962 \n", + "1 241.769932 -33.905933 7.440188 1722.536053 48.034556 \n", + "2 25.344904 0.769463 -11.237007 -335.794326 -3.595284 \n", + "3 15.635557 -28.371864 -19.808469 800.088446 61.404066 \n", + "4 -0.796959 -8.546869 -4.726590 128.343028 1.083628 \n", + "5 -31.203381 -39.052177 -57.651032 1269.158981 -22.793850 \n", + "6 -120.526480 -49.314650 -67.642982 650.136816 65.155843 \n", + "7 13.172627 -7.196406 -20.153565 746.262383 -30.846688 \n", + "\n", + " free sulfur dioxide total sulfur dioxide density pH \\\n", + "0 450.073724 1221.551077 5.697177 12.451377 \n", + "1 1312.264457 2141.119368 31.259074 83.654749 \n", + "2 -234.179124 382.907515 7.637684 17.748300 \n", + "3 -596.053591 -1749.797505 28.376345 -71.868790 \n", + "4 -288.352104 1184.680273 8.081500 23.012828 \n", + "5 101.490751 -661.997823 5.012738 19.615822 \n", + "6 598.106999 -3468.753037 3.750566 52.556860 \n", + "7 1592.815397 1610.699379 -15.576660 27.319692 \n", + "\n", + " sulphates alcohol quality \n", + "0 14.835445 81.515696 0.0 \n", + "1 28.591759 489.172674 3.0 \n", + "2 3.380296 73.701048 1.0 \n", + "3 -14.556346 -38.315179 1.0 \n", + "4 2.168597 36.672840 0.0 \n", + "5 26.791456 -63.773678 3.0 \n", + "6 -108.310847 -91.816310 3.0 \n", + "7 45.376814 135.871422 0.0 " + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.model.generate(8)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "be62c2f0", + "metadata": {}, + "source": [ + "### Conditional data generation\n", + "\n", + "A conditional variable `cond` can be provided to the `fit` method. It can be either a column name in the dataset or a custom array. The model will then learn the conditional distribution of the dataset given `cond`. 
In this case, an array must be provided as the `cond` argument of the `generate` method." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "56a1fc7e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-03-27T18:03:45.005934+0200][38480][INFO] Step 100: MLoss: 0.9066 GLoss: 1.0013 Sum: 1.9079000000000002\n", + "[2023-03-27T18:03:51.387087+0200][38480][INFO] Step 200: MLoss: 0.4735 GLoss: 1.0112 Sum: 1.4847000000000001\n", + "[2023-03-27T18:03:59.107456+0200][38480][INFO] Step 300: MLoss: 0.4567 GLoss: 1.001 Sum: 1.4577\n", + "[2023-03-27T18:04:05.835508+0200][38480][INFO] Step 400: MLoss: 0.2715 GLoss: 0.9856 Sum: 1.2571\n", + "[2023-03-27T18:04:12.739590+0200][38480][INFO] Step 500: MLoss: 0.2193 GLoss: 0.9046 Sum: 1.1239\n", + "[2023-03-27T18:04:19.417762+0200][38480][INFO] Step 600: MLoss: 0.0143 GLoss: 0.8463 Sum: 0.8606\n", + "[2023-03-27T18:04:26.022729+0200][38480][INFO] Step 700: MLoss: 0.0048 GLoss: 0.7509 Sum: 0.7557\n", + "[2023-03-27T18:04:32.757598+0200][38480][INFO] Step 800: MLoss: 0.0083 GLoss: 0.7102 Sum: 0.7185\n", + "[2023-03-27T18:04:39.550873+0200][38480][INFO] Step 900: MLoss: 0.0029 GLoss: 0.675 Sum: 0.6779000000000001\n", + "[2023-03-27T18:04:46.573464+0200][38480][INFO] Step 1000: MLoss: 0.0039 GLoss: 0.6414 Sum: 0.6453\n", + "[2023-03-27T18:04:53.438631+0200][38480][INFO] Step 1100: MLoss: 0.003 GLoss: 0.6046 Sum: 0.6076\n", + "[2023-03-27T18:05:01.283222+0200][38480][INFO] Step 1200: MLoss: 0.0013 GLoss: 0.6297 Sum: 0.631\n", + "[2023-03-27T18:05:08.559280+0200][38480][INFO] Step 1300: MLoss: 0.0012 GLoss: 0.5479 Sum: 0.5491\n", + "[2023-03-27T18:05:15.536738+0200][38480][INFO] Step 1400: MLoss: 0.0067 GLoss: 0.5275 Sum: 0.5342\n", + "[2023-03-27T18:05:22.391711+0200][38480][INFO] Step 1500: MLoss: 0.0007 GLoss: 0.5252 Sum: 0.5259\n", + "[2023-03-27T18:05:29.285959+0200][38480][INFO] Step 1600: MLoss: 0.0018 GLoss: 0.5017 Sum: 0.5035000000000001\n", + 
"[2023-03-27T18:05:36.288634+0200][38480][INFO] Step 1700: MLoss: 0.0012 GLoss: 0.5013 Sum: 0.5025\n", + "[2023-03-27T18:05:43.485831+0200][38480][INFO] Step 1800: MLoss: 0.0009 GLoss: 0.4927 Sum: 0.49360000000000004\n", + "[2023-03-27T18:05:50.629387+0200][38480][INFO] Step 1900: MLoss: 0.0009 GLoss: 0.4931 Sum: 0.494\n", + "[2023-03-27T18:05:58.709478+0200][38480][INFO] Step 2000: MLoss: 0.0006 GLoss: 0.4864 Sum: 0.487\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.fit(loader, cond='quality')" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "3fcb9493", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plugin.loss_history.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "2ea981cd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([3, 4, 5, 6, 7, 8, 9])" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outcome = np.array([3, 4, 5, 6, 7, 8, 9])\n", + "outcome" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "bbd33233", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-03-27T18:05:59.734678+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:05:59.737612+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:00.157952+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:00.160095+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:00.484737+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:00.485757+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:00.786487+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:00.788466+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. 
Original dtype float64.\n", + "[2023-03-27T18:06:01.100020+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:01.102261+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:01.460078+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:01.462163+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:01.805568+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:01.807568+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:02.183897+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:02.185904+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:02.569835+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:02.571874+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:03.033272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:03.035146+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:03.579187+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:03.582312+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:04.128201+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:04.131216+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:04.909594+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:04.912681+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:05.506491+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:05.509890+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:06.285555+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:06.287092+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:06.748144+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:06.751143+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:07.239364+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:07.241364+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:07.833861+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:07.835862+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:08.270020+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:08.273103+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:08.579762+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:08.581664+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:08.995746+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:08.996750+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:09.387130+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:09.389133+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:09.913255+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:09.915271+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:10.414403+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:10.417511+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:10.986099+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:10.988092+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:11.384006+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:11.699391+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:11.700392+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:12.138923+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:12.140923+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:12.604077+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:12.606674+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:12.997333+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:12.999663+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:13.547570+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:13.550541+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:13.954516+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:13.956516+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:14.445116+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:14.452112+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:14.829066+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:14.832071+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:15.312829+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:15.315831+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:15.757355+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:15.759926+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:16.143136+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:16.145134+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:16.560027+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:16.562025+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:16.861918+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:16.863918+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:17.183558+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:17.637582+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:17.640281+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:17.997687+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:17.998689+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:18.345381+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:18.347383+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:18.676026+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:18.678607+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:19.007549+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:19.010506+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:19.346424+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:19.348531+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:19.696186+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:19.697186+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:20.073809+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:20.077249+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:06:20.472414+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:20.475399+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:20.942265+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:20.944268+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:21.302342+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:21.304314+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:21.665401+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:21.666987+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:22.067854+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:22.069371+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:22.392718+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:22.395677+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:22.716515+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:22.717515+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:06:23.047434+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:23.049434+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:23.399152+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:23.401151+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:23.745625+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:23.747624+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:24.098540+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:24.099540+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:24.421854+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:24.422839+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:24.738758+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:24.739667+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:25.058648+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:25.060550+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:25.399681+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:25.401599+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:25.737806+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:25.738793+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:26.069784+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:26.071290+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:26.416549+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:26.418554+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:26.801542+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:26.803529+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:27.139240+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:27.141225+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:27.488070+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:27.490052+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:27.823788+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:27.824814+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:28.163857+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:28.166838+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:28.499341+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:28.501342+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:28.823408+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:28.825499+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:29.125222+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:29.128129+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:29.492914+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:29.496428+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:29.833079+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:29.835167+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:30.217776+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:30.219777+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:30.536676+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:30.538667+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:30.861816+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:30.863812+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:31.177127+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:31.180126+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:06:31.606751+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:31.607978+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:31.949768+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:31.951785+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:32.289786+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:32.291671+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:32.629730+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:32.942556+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:32.945560+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:32.949202+0200][38480][INFO] [residual sugar] quality loss for constraints ge = 0.6. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:33.286281+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:33.287280+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:06:33.620445+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:33.622445+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:33.945427+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:33.947494+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:34.298877+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:34.300955+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:34.618880+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:34.620789+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:34.959467+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:34.961383+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:35.296247+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:35.298303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:35.763113+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:35.765112+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:36.178981+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:36.181338+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:36.555008+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:36.555991+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:36.880093+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:36.881104+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:37.299044+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:37.301205+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:37.708557+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:37.711544+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:38.087165+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:38.089166+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:38.482563+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:38.483562+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:38.941184+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:38.942166+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:39.291995+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:39.294883+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:39.642425+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:39.645485+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:39.965926+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:39.967445+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:40.280863+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:40.281866+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:40.567363+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:40.569362+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:40.863820+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:40.865893+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:41.406311+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:41.409435+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:42.003319+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:42.006307+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:42.470804+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:42.471786+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:42.768361+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:42.770360+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", + "[2023-03-27T18:06:43.102405+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:43.105718+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:06:43.426329+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:43.429478+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:43.757004+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:43.759124+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:44.083407+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:44.084408+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:44.400443+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:44.401428+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:44.706402+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:44.708999+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:45.018534+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:45.019535+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:45.519397+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:45.521407+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:45.921477+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:45.922985+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:46.265432+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:46.267956+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:46.717722+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:46.719733+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:47.062693+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:47.064691+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:47.417125+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:47.418108+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:47.758309+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:47.760595+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:48.135817+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:48.137801+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:48.458595+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:48.460608+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:48.754069+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:48.756024+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:49.049862+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:49.051462+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:49.350537+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:49.352536+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:49.766318+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:49.769390+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:50.276306+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:50.279351+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:50.665664+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:50.666685+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:51.009462+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.012707+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:51.308313+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.309313+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:51.637138+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.639120+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", + "[2023-03-27T18:06:51.979944+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:51.980946+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:52.297063+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:52.298062+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:52.625280+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:52.628303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:52.938341+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:52.939345+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:53.233624+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:53.235624+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:53.550284+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:53.552284+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:53.859100+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:53.863229+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:54.227895+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:54.229895+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:54.534473+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:54.536457+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:06:54.835486+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:54.837487+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:55.132594+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:55.134593+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:55.465635+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:55.467185+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:55.807745+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:06:55.810517+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:56.300336+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:56.302923+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:56.604424+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:56.605423+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:56.898530+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:56.900544+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:57.205520+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:57.206520+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:57.503438+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:57.505437+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:57.819558+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:57.821581+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:06:58.160813+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:58.163417+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:58.462315+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:58.463303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:06:58.815614+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:58.817596+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:06:59.129940+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:59.130934+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:06:59.577632+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:59.580621+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:06:59.909210+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:06:59.910211+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", + "[2023-03-27T18:07:00.263906+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:00.265906+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:00.573175+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:00.574177+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:00.866210+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:00.868793+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:01.205344+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:01.207327+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:01.606906+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:01.608906+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:02.102300+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:02.105211+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:02.503969+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:02.506485+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:02.906864+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:02.908864+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:03.298141+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:03.300142+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:03.619687+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:03.621670+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:03.942307+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:03.946964+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:04.383317+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:04.384318+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:04.685032+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:04.687584+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:04.985829+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:04.986829+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:05.266858+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:05.269157+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:05.580166+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:05.582149+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:05.889785+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:05.892186+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:06.211209+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:06.213722+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:06.513714+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:06.515729+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:06.832167+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:06.834177+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:07.144798+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:07.146797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:07.479304+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:07.481835+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:07.846999+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:07.848997+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:08.195789+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:08.197813+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:08.200813+0200][38480][INFO] [residual sugar] quality loss for constraints le = 65.8. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:08.691113+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:08.694249+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:09.231893+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:09.235438+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:09.713446+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:09.716162+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:10.805837+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:10.809012+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:11.446846+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:11.450600+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:12.110297+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:12.114136+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:12.587219+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:12.589217+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:13.186604+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:13.188628+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:13.765722+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:13.767730+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:14.222493+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:14.225273+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:14.581621+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:14.582622+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:14.916005+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:14.917005+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:15.232768+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:15.233771+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. 
Original dtype float64.\n", + "[2023-03-27T18:07:15.587426+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:15.589426+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:15.937914+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:15.939914+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:16.341209+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:16.343228+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:16.667291+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:16.669292+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:16.989838+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:16.991912+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:17.306825+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:17.308797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:17.659105+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:17.661131+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:18.018946+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:18.019947+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:18.393086+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:18.396311+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:18.830421+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:18.833527+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:19.232926+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:19.236012+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:19.669845+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:19.672139+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:20.034654+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:20.035654+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:20.365288+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:20.367291+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:20.677852+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:20.680692+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:20.988636+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:20.990732+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:21.326922+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:21.329905+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:21.682149+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:21.684150+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:22.042272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:22.043272+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:22.417916+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:22.418916+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:22.749237+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:22.751237+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:23.090475+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:23.091459+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:23.470508+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:23.473305+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:23.821072+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:23.823567+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:23.827191+0200][38480][INFO] [residual sugar] quality loss for constraints ge = 0.6. Remaining 0. prev length 1. Original dtype float64.\n", + "[2023-03-27T18:07:24.193607+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", + "[2023-03-27T18:07:24.194590+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:24.532529+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:24.534525+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:24.876586+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:24.878585+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:25.216076+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:25.217076+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:25.599528+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:25.601333+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:26.159795+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:26.161982+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:26.541276+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:26.542274+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:26.869887+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:26.872038+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:27.183814+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:27.186139+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:27.522592+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:27.524574+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:27.885528+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:27.886547+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", + "[2023-03-27T18:07:28.236311+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:28.237310+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", + "[2023-03-27T18:07:28.569622+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:28.571622+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", + "[2023-03-27T18:07:28.889372+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:28.890372+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:29.200272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:29.202272+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:29.533137+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:29.535216+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", + "[2023-03-27T18:07:29.936280+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:29.939026+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:30.369796+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:30.371797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", + "[2023-03-27T18:07:30.718054+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:30.720128+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", + "[2023-03-27T18:07:31.139806+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 5. prev length 7. Original dtype float64.\n", + "[2023-03-27T18:07:31.140809+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mplugin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutcome\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, count, constraints, random_state, **kwargs)\u001b[0m\n\u001b[0;32m 337\u001b[0m \u001b[0msyn_schema\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mSchema\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_constraints\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgen_constraints\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 338\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 339\u001b[1;33m \u001b[0mX_syn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_generate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msyn_schema\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msyn_schema\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 340\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 341\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mX_syn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_tabular\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36m_generate\u001b[1;34m(self, count, syn_schema, **kwargs)\u001b[0m\n\u001b[0;32m 246\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 247\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 248\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_safe_generate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msyn_schema\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 249\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 250\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + 
"\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36m_safe_generate\u001b[1;34m(self, gen_cbk, count, syn_schema, **kwargs)\u001b[0m\n\u001b[0;32m 391\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msampling_patience\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 392\u001b[0m \u001b[1;31m# sample\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 393\u001b[1;33m \u001b[0miter_samples\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgen_cbk\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 394\u001b[0m iter_samples_df = pd.DataFrame(\n\u001b[0;32m 395\u001b[0m \u001b[0miter_samples\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining_schema\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36mcallback\u001b[1;34m(count)\u001b[0m\n\u001b[0;32m 241\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 242\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# type: ignore\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 243\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 244\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_classification\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_iloc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\__init__.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, count, cond)\u001b[0m\n\u001b[0;32m 211\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 212\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 213\u001b[1;33m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiffusion\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 214\u001b[0m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_col_perm\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\gaussian_multinomial_diffsuion.py\u001b[0m in \u001b[0;36msample_all\u001b[1;34m(self, num_samples, cond, max_batch_size, ddim)\u001b[0m\n\u001b[0;32m 951\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 952\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mb\u001b[0m \u001b[1;32min\u001b[0m 
\u001b[0mbs\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 953\u001b[1;33m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msample_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 954\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0many\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msample\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 955\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"found NaNs in sample\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\autograd\\grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 28\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\gaussian_multinomial_diffsuion.py\u001b[0m in \u001b[0;36msample\u001b[1;34m(self, num_samples, cond)\u001b[0m\n\u001b[0;32m 918\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"Sample timestep {i:4d}\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"\\r\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 919\u001b[0m \u001b[0mt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfull\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 920\u001b[1;33m model_out = self.denoise_fn(\n\u001b[0m\u001b[0;32m 921\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mz_norm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlog_z\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 922\u001b[0m )\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\modules.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, t, y)\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0memb\u001b[0m \u001b[1;33m+=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0memb_nonlin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlabel_emb\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0memb\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 113\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\mlp.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mvalidate_arguments\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marbitrary_types_allowed\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 399\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 400\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 401\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 402\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_train_epoch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloader\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mfloat\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in 
\u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 204\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\mlp.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mvalidate_arguments\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marbitrary_types_allowed\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 204\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "plugin.generate(len(outcome), cond=outcome)" + ] + }, + { + "cell_type": "markdown", + "id": "ea5abc50", + "metadata": {}, + "source": [ + "## Congratulations!\n", + "\n", + "Congratulations on completing this notebook tutorial! 
If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!\n", + "\n", + "### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub\n", + "\n", + "- The easiest way to help our community is just by starring the Repos! This helps raise awareness of the tools we're building.\n", + "\n", + "\n", + "### Checkout other projects from vanderschaarlab\n", + "- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)\n", + "- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 8120e975dc0279e294847f6f2a2e44f1990775cc Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 30 Mar 2023 16:01:53 +0200 Subject: [PATCH 29/95] major update of FeatureEncoder and TabularEncoder --- .../plugins/core/models/data_encoder.py | 198 ++++++++----- .../gaussian_multinomial_diffsuion.py | 7 +- .../plugins/core/models/tabular_encoder.py | 280 +++++++----------- ...al8_tabular_modelling_with_diffusion.ipynb | 1 - 4 files changed, 230 insertions(+), 256 deletions(-) diff --git a/src/synthcity/plugins/core/models/data_encoder.py b/src/synthcity/plugins/core/models/data_encoder.py index 9f432d9c..57fdbc1c 100644 --- a/src/synthcity/plugins/core/models/data_encoder.py +++ b/src/synthcity/plugins/core/models/data_encoder.py @@ -1,6 +1,5 @@ # stdlib -from functools import wraps -from typing import Any, List, Optional, Union +from typing import Any, List, Type, Union # third party import numpy as np @@ -15,68 +14,111 @@ 
StandardScaler, ) +FeatureEncoder = Any -class _DataEncoder(TransformerMixin, BaseEstimator): - """Base data encoder, with sklearn-style API""" - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit(self, X: Any) -> Any: - return self._fit(X) +class FeatureEncoder(TransformerMixin, BaseEstimator): # type: ignore + """Base feature encoder, with sklearn-style API""" - def _fit(self, X: Any) -> Any: - return self + def __new__(cls, **kwargs: Any) -> FeatureEncoder: + obj = super().__new__() + obj.__dict__.update(kwargs) # auto set all parameters as attributes + return obj @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def transform(self, X: Any) -> Any: - return self._transform(X) + def fit(self, x: pd.Series, y: Any = None, **kwargs: Any) -> FeatureEncoder: + self.feature_name_in = x.name + out = self._fit(x, **kwargs)._transform(x) - def _transform(self, X: Any) -> Any: - return X + if np.ndim(out) == 1: + self.n_features_out = 1 + else: + self.n_features_out = np.shape(out)[1] - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def inverse_transform(self, X: Any) -> Any: - return self._inverse_transform(X) + self.feature_names_out = self.get_feature_names_out() - def _inverse_transform(self, X: Any) -> Any: - return X + return self + + def _fit(self, x: pd.Series, **kwargs: Any) -> FeatureEncoder: + return self + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def transform(self, x: pd.Series) -> Any: + out = self._transform(x) + if isinstance(out, np.ndarray): + if out.ndim == 1: + return pd.Series(out, self.feature_name_in) + else: + return pd.DataFrame(out, columns=self.feature_names_out) + return out + + def _transform(self, x: pd.Series) -> Any: + return x + + def get_feature_names_out(self) -> List[str]: + n = self.n_features_out + if n == 1: + return [self.feature_name_in] + else: + return [self.feature_name_in + str(i) for i in range(n)] 
@validate_arguments(config=dict(arbitrary_types_allowed=True)) - def fit_transform(self, X: Any) -> Any: - return self.fit(X).transform(X) + def inverse_transform(self, data: Any) -> pd.Series: + x = self._inverse_transform(data) + return pd.Series(x, name=self.feature_name_in) + + def _inverse_transform(self, data: Any) -> pd.Series: + return data @classmethod - def wraps(cls, encoder_class: TransformerMixin) -> type: - """Wraps sklearn encoder to DataEncoder.""" + def wraps(cls, encoder_class: TransformerMixin) -> Type[FeatureEncoder]: + """Wraps sklearn transformer to FeatureEncoder.""" - @wraps(encoder_class) - class WrappedEncoder(_DataEncoder): + class WrappedEncoder(FeatureEncoder): def __init__(self, *args: Any, **kwargs: Any) -> None: self.encoder = encoder_class(*args, **kwargs) - def _fit(self, X: Any) -> _DataEncoder: - self.encoder.fit(X) + def _fit(self, x: pd.Series, **kwargs: Any) -> FeatureEncoder: + self.encoder.fit(x, **kwargs) return self - def _transform(self, X: Any) -> Any: - return self.encoder.transform(X) + def _transform(self, x: pd.Series) -> Any: + return self.encoder.transform(x) - def _inverse_transform(self, X: Any) -> Any: - return self.encoder.inverse_transform(X) + def _inverse_transform(self, x: pd.Series) -> Any: + return self.encoder.inverse_transform(x) + + def get_feature_names_out(self) -> List[str]: + return self.encoder.get_feature_names_out([self.feature_name_in]) + + for attr in ( + "__module__", + "__name__", + "__qualname__", + "__doc__", + "__annotations__", + ): + setattr(WrappedEncoder, attr, getattr(encoder_class, attr)) return WrappedEncoder -class DatetimeEncoder(_DataEncoder): +OneHotEncoder = FeatureEncoder.wraps(OneHotEncoder) +StandardScaler = FeatureEncoder.wraps(StandardScaler) +MinMaxScaler = FeatureEncoder.wraps(MinMaxScaler) + + +class DatetimeEncoder(FeatureEncoder): """Datetime variables encoder""" - def _transform(self, X: pd.Series) -> pd.Series: - return pd.to_numeric(X).astype(float) + def 
_transform(self, x: pd.Series) -> pd.Series: + return pd.to_numeric(x).astype(float) - def _inverse_transform(self, X: pd.Series) -> pd.Series: - return pd.to_datetime(X) + def _inverse_transform(self, data: pd.Series) -> pd.Series: + return pd.to_datetime(data) -class BayesianGMMEncoder(_DataEncoder): +class BayesianGMMEncoder(FeatureEncoder): """Bayesian Gaussian Mixture encoder""" def __init__( @@ -84,70 +126,72 @@ def __init__( n_components: int = 10, random_state: int = 0, weight_threshold: float = 0.005, + clip_output: bool = True, + std_multiplier: int = 4, ) -> None: self.model = BayesianGaussianMixture( n_components=n_components, random_state=random_state, weight_concentration_prior=1e-3, ) - self.n_components = n_components - self.weight_threshold = weight_threshold - self.weights: Optional[List[float]] = None - self.std_multiplier = 4 + self.weights: List[float] - def _fit(self, X: pd.DataFrame) -> Any: - self.min_value = X.min() - self.max_value = X.max() + def _fit(self, x: pd.Series, **kwargs: Any) -> "BayesianGaussianMixture": + self.min_value = x.min() + self.max_value = x.max() - self.model.fit(X.values.reshape(-1, 1)) + self.model.fit(x.values.reshape(-1, 1)) self.weights = self.model.weights_ - self.n_components = len(self.model.weights_) + self.means = self.model.means_.reshape(-1) + self.stds = np.sqrt(self.model.covariances_).reshape(-1) return self - def _transform(self, X: pd.DataFrame) -> pd.DataFrame: - name = X.name - X = X.values.reshape(-1, 1) - means = self.model.means_.reshape(1, self.n_components) + def _transform(self, x: pd.Series) -> pd.DataFrame: + x = x.values.reshape(-1, 1) + means = self.means.reshape(1, -1) + stds = self.stds.reshape(1, -1) # predict cluster value - stds = np.sqrt(self.model.covariances_).reshape(1, self.n_components) - - normalized_values = (X - means) / (self.std_multiplier * stds) + normalized_values = (x - means) / (self.std_multiplier * stds) # predict cluster - component_probs = 
self.model.predict_proba(X) + component_probs = self.model.predict_proba(x) components = np.argmax(component_probs, axis=1) - aranged = np.arange(len(X)) - normalized = normalized_values[aranged, components].reshape([-1, 1]) - normalized = np.clip(normalized, -0.99, 0.99).squeeze(axis=1) - out = np.stack([normalized, components], axis=1) + normalized = normalized_values[np.arange(len(x)), components] + if self.clip_output: + normalized = np.clip(normalized, -0.99, 0.99) + normalized = normalized.reshape(-1, 1) - return pd.DataFrame(out, columns=[f"{name}.value", f"{name}.component"]) + components = np.eye(self.n_components)[components] # onehot + return np.hstack([normalized, components]) + + def get_feature_names_out(self) -> List[str]: + name = self.feature_name_in + return [f"{name}.value"] + [ + f"{name}.component_{i}" for i in range(self.n_features_out - 1) + ] + + def _inverse_transform(self, data: pd.DataFrame) -> pd.Series: + if self.clip_output: + data = np.clip(data.values[:, 0], -1, 1) - def _inverse_transform(self, X: pd.DataFrame) -> pd.DataFrame: - normalized = np.clip(X.values[:, 0], -1, 1) means = self.model.means_.reshape([-1]) stds = np.sqrt(self.model.covariances_).reshape([-1]) - selected_component = X.values[:, 1].astype(int) + components = np.argmax(data.values[:, 1:], axis=1) # recreate data - std_t = stds[selected_component] - mean_t = means[selected_component] - reversed_data = normalized * self.std_multiplier * std_t + mean_t + std_t = stds[components] + mean_t = means[components] + reversed_data = data * self.std_multiplier * std_t + mean_t # clip values return np.clip(reversed_data, self.min_value, self.max_value) -OneHotEncoder = _DataEncoder.wraps(OneHotEncoder) -StandardScaler = _DataEncoder.wraps(StandardScaler) -MinMaxScaler = _DataEncoder.wraps(MinMaxScaler) - - -@_DataEncoder.wraps +@FeatureEncoder.wraps class GaussianQuantileTransformer(QuantileTransformer): """Quantile transformer with Gaussian distribution""" @@ -168,12 +212,12 
@@ def __init__( copy=copy, ) - def fit(self, X: pd.DataFrame, y: Any = None) -> "GaussianQuantileTransformer": - self.n_quantiles = max(min(len(X) // 30, 1000), 10) - return super().fit(X, y) + def fit(self, x: pd.Series, y: Any = None) -> "GaussianQuantileTransformer": + self.n_quantiles = max(min(len(x) // 30, 1000), 10) + return super().fit(x, y) -REGISTRY = { +ENCODERS = { "datetime": DatetimeEncoder, "onehot": OneHotEncoder, "standard": StandardScaler, @@ -183,7 +227,7 @@ def fit(self, X: pd.DataFrame, y: Any = None) -> "GaussianQuantileTransformer": } -def get_encoder(encoder: Union[str, type]) -> TransformerMixin: +def get_encoder(encoder: Union[str, type]) -> Type[FeatureEncoder]: """Get a registered encoder. Supported encoders: @@ -198,5 +242,5 @@ def get_encoder(encoder: Union[str, type]) -> TransformerMixin: - bayesian_gmm """ if isinstance(encoder, type): # custom encoder - return encoder - return REGISTRY[encoder] + return FeatureEncoder.wraps(encoder) + return ENCODERS[encoder] diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index db55aedb..990d6cbb 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -946,11 +946,12 @@ def sample_all( else: sample_fn = self.sample - bs = np.diff([*range(0, num_samples, max_batch_size), num_samples]) + indices = [*range(0, num_samples, max_batch_size), num_samples] all_samples = [] - for b in bs: - sample = sample_fn(b, cond) + for i, b in enumerate(np.diff(indices)): + c = None if cond is None else cond[indices[i] : indices[i + 1]] + sample = sample_fn(b, c) if torch.any(sample.isnan()).item(): raise ValueError("found NaNs in sample") all_samples.append(sample) diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py 
b/src/synthcity/plugins/core/models/tabular_encoder.py index a9bb0e82..0eb2a096 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -2,7 +2,7 @@ """ # stdlib -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Union # third party import numpy as np @@ -16,13 +16,13 @@ from synthcity.utils.serialization import dataframe_hash # synthcity relative -from .data_encoder import get_encoder +from .data_encoder import FeatureEncoder, get_encoder class FeatureInfo(BaseModel): name: str feature_type: str - transform: Any + transform: FeatureEncoder output_dimensions: int transformed_features: List[str] @@ -56,95 +56,70 @@ class TabularEncoder(TransformerMixin, BaseEstimator): Discrete columns are encoded using a scikit-learn OneHotEncoder. """ + categorical_encoder: Union[str, type] = "onehot" + continuous_encoder: Union[str, type] = "bayesian_gmm" + cat_encoder_params: dict = dict(handle_unknown="ignore", sparse=False) + cont_encoder_params: dict = dict(n_components=10) + @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, - max_clusters: int = 10, + *, + whitelist: tuple = (), categorical_limit: int = 10, - whitelist: list = [], - categorical_encoder: str = "onehot", - continuous_encoder: str = "bayesian_gmm", + categorical_encoder: Optional[Union[str, type]] = None, + continuous_encoder: Optional[Union[str, type]] = None, + cat_encoder_params: Optional[dict] = None, + cont_encoder_params: Optional[dict] = None, ) -> None: """Create a data transformer. Args: - max_clusters (int): - Maximum number of Gaussian distributions in Bayesian GMM. + whitelist (tuple): + Columns that will not be transformed. 
""" - self.max_clusters = max_clusters self.categorical_limit = categorical_limit self.whitelist = whitelist - self.categorical_encoder = categorical_encoder - self.continuous_encoder = continuous_encoder - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def _fit_continuous(self, data: pd.Series) -> FeatureInfo: - """Fit the continuous encoder on a continuous column. - - Args: - data (pd.DataFrame): - A dataframe containing a column. - - Returns: - namedtuple: - A ``FeatureInfo`` object. - """ - name = data.name - - if self.continuous_encoder == "bayesian_gmm": - encoder = get_encoder("bayesian_gmm")( - n_components=min(self.max_clusters, len(data)), - ) - n_components = encoder.n_components - dim_out = 1 + n_components - transformed_features = [f"{name}.value"] + [ - f"{name}.component_{i}" for i in range(n_components) - ] + if categorical_encoder is not None: + self.categorical_encoder = categorical_encoder + if continuous_encoder is not None: + self.continuous_encoder = continuous_encoder + if cat_encoder_params is not None: + self.cat_encoder_params = cat_encoder_params else: - encoder = get_encoder(self.continuous_encoder)() - dim_out = 1 - transformed_features = [name] - - encoder.fit(data) - - return FeatureInfo( - name=name, - feature_type="continuous", - transform=encoder, - output_dimensions=dim_out, - transformed_features=transformed_features, - ) + self.cat_encoder_params = self.cat_encoder_params.copy() + if cont_encoder_params is not None: + self.cont_encoder_params = cont_encoder_params + else: + self.cont_encoder_params = self.cont_encoder_params.copy() @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def _fit_discrete(self, data: pd.Series) -> FeatureInfo: - """Fit one hot encoder for discrete column. + def _fit_feature(self, feature: pd.Series, feature_type: str) -> FeatureInfo: + """Fit the feature encoder on a column. Args: - data (pd.DataFrame): - A dataframe containing a column. 
+ feature (pd.Series): + A column of a dataframe. + feature_type (str): + Type of the feature ('discrete' or 'continuous'). Returns: - namedtuple: - A ``FeatureInfo`` object. + FeatureInfo: + Information of the fitted feature encoder. """ - name = data.name - - if self.categorical_encoder == "onehot": - encoder = get_encoder("onehot")(handle_unknown="ignore", sparse=False) + if feature_type == "discrete": + encoder = get_encoder(self.categorical_encoder)(**self.cat_encoder_params) else: - raise ValueError(f"Unknown categorical encoder {self.categorical_encoder}") + encoder = get_encoder(self.continuous_encoder)(**self.cont_encoder_params) - encoder.fit(data.values.reshape(-1, 1)) - num_categories = len(encoder.categories_[0]) - - transformed_features = list(encoder.get_feature_names_out([data.name])) + encoder.fit(feature) return FeatureInfo( - name=name, - feature_type="discrete", + name=feature.name, + feature_type=feature_type, transform=encoder, - output_dimensions=num_categories, - transformed_features=transformed_features, + output_dimensions=encoder.n_features_out, + transformed_features=encoder.feature_names_out, ) @validate_arguments(config=dict(arbitrary_types_allowed=True)) @@ -161,81 +136,51 @@ def fit( self.output_dimensions = 0 self._column_raw_dtypes = raw_data.infer_objects().dtypes - self._column_transform_info = [] + self._column_transform_info_list = [] for name in raw_data.columns: if name in self.whitelist: continue column_hash = dataframe_hash(raw_data[[name]]) log.info(f"Encoding {name} {column_hash}") - if name in discrete_columns: - column_transform_info = self._fit_discrete(raw_data[name]) + ftype = "discrete" else: - column_transform_info = self._fit_continuous(raw_data[name]) + ftype = "continuous" + column_transform_info = self._fit_feature(raw_data[name], ftype) self.output_dimensions += column_transform_info.output_dimensions - self._column_transform_info.append(column_transform_info) + 
self._column_transform_info_list.append(column_transform_info) return self - def _transform_continuous( - self, column_transform_info: FeatureInfo, data: pd.Series - ) -> pd.DataFrame: - name = data.name - encoder = column_transform_info.transform - transformed = encoder.transform(data) - - # Converts the transformed data to the appropriate output format. - if self.continuous_encoder == "bayesian_gmm": - output = np.zeros( - (len(transformed), column_transform_info.output_dimensions) - ) - output[:, 0] = transformed[f"{name}.value"].to_numpy() - index = transformed[f"{name}.component"].to_numpy().astype(int) - output[np.arange(index.size), index + 1] = 1 - else: - output = transformed.to_numpy().reshape(-1, 1) - - return pd.DataFrame( - output, - columns=column_transform_info.transformed_features, - ) - - def _transform_discrete( - self, column_transform_info: FeatureInfo, data: pd.Series + def _transform_feature( + self, column_transform_info: FeatureInfo, feature: pd.Series ) -> pd.DataFrame: encoder = column_transform_info.transform return pd.DataFrame( - encoder.transform(data.to_frame().values), + encoder.transform(feature).values, columns=column_transform_info.transformed_features, ) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: """Take raw data and output a matrix data.""" - if len(self._column_transform_info) == 0: + if len(self._column_transform_info_list) == 0: return pd.DataFrame(np.zeros((len(raw_data), 0))) column_data_list = [] for name in self.whitelist: if name not in raw_data.columns: continue - data = raw_data[name] - column_data_list.append(data) + feature = raw_data[name] + column_data_list.append(feature) - for column_transform_info in self._column_transform_info: - name = column_transform_info.name - data = raw_data[name] - - if column_transform_info.feature_type == "continuous": - column_data_list.append( - self._transform_continuous(column_transform_info, data) - ) 
- else: - column_data_list.append( - self._transform_discrete(column_transform_info, data) - ) + for column_transform_info in self._column_transform_info_list: + feature = raw_data[column_transform_info.name] + column_data_list.append( + self._transform_feature(column_transform_info, feature) + ) result = pd.concat(column_data_list, axis=1) result.index = raw_data.index @@ -243,31 +188,13 @@ def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: return result @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def _inverse_transform_continuous( + def _inverse_transform_feature( self, column_transform_info: FeatureInfo, column_data: pd.DataFrame, - ) -> pd.DataFrame: - encoder = column_transform_info.transform - if self.continuous_encoder == "bayesian_gmm": - data = pd.DataFrame( - column_data.values[:, :2], columns=["value", "component"] - ) - data.iloc[:, 1] = np.argmax(column_data.values[:, 1:], axis=1) - else: - data = column_data - return encoder.inverse_transform(data) - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def _inverse_transform_discrete( - self, column_transform_info: FeatureInfo, column_data: pd.DataFrame - ) -> pd.DataFrame: + ) -> pd.Series: encoder = column_transform_info.transform - column = column_transform_info.name - return pd.DataFrame( - encoder.inverse_transform(column_data), - columns=[column], - ) + return encoder.inverse_transform(column_data) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def inverse_transform(self, data: pd.DataFrame) -> pd.DataFrame: @@ -275,39 +202,32 @@ def inverse_transform(self, data: pd.DataFrame) -> pd.DataFrame: Output uses the same type as input to the transform function. 
""" - if len(self._column_transform_info) == 0: + if len(self._column_transform_info_list) == 0: return pd.DataFrame(np.zeros((len(data), 0))) st = 0 - recovered_column_data_list = [] names = [] feature_types = [] + recovered_feature_list = [] for name in self.whitelist: if name not in data.columns: continue - local_data = data[name] names.append(name) feature_types.append(self._column_raw_dtypes) - recovered_column_data_list.append(local_data) + recovered_feature_list.append(data[name]) - for column_transform_info in self._column_transform_info: + for column_transform_info in self._column_transform_info_list: dim = column_transform_info.output_dimensions column_data = data.iloc[:, list(range(st, st + dim))] - if column_transform_info.feature_type == "continuous": - recovered_column_data = self._inverse_transform_continuous( - column_transform_info, column_data - ) - else: - recovered_column_data = self._inverse_transform_discrete( - column_transform_info, column_data - ) - - recovered_column_data_list.append(recovered_column_data) + recovered_feature = self._inverse_transform_feature( + column_transform_info, column_data + ) + recovered_feature_list.append(recovered_feature) names.append(column_transform_info.name) st += dim - recovered_data = np.column_stack(recovered_column_data_list) + recovered_data = np.column_stack(recovered_feature_list) recovered_data = pd.DataFrame( recovered_data, columns=names, index=data.index ).astype(self._column_raw_dtypes.filter(names)) @@ -320,18 +240,16 @@ def layout(self) -> List[Tuple]: - continuous, and with length 1 + number of GMM clusters. - discrete, and with length , the length of the one-hot encoding. 
""" - return self._column_transform_info + return self._column_transform_info_list def n_features(self) -> int: return np.sum( - [ - column_transform_info.output_dimensions - for column_transform_info in self._column_transform_info - ] + column_transform_info.output_dimensions + for column_transform_info in self._column_transform_info_list ) def get_column_info(self, name: str) -> FeatureInfo: - for column_transform_info in self._column_transform_info: + for column_transform_info in self._column_transform_info_list: if column_transform_info.name == name: return column_transform_info @@ -348,7 +266,7 @@ def activation_layout( - discrete, and with length , the length of the one-hot encoding. """ out = [] - for column_transform_info in self._column_transform_info: + for column_transform_info in self._column_transform_info_list: if column_transform_info.feature_type == "continuous": out.extend( [ @@ -374,26 +292,38 @@ class BinEncoder(TabularEncoder): Discrete columns are encoded using a scikit-learn OneHotEncoder. 
""" - def _transform_continuous( - self, column_transform_info: FeatureInfo, data: pd.Series - ) -> pd.Series: - name = data.name - encoder = column_transform_info.transform - transformed = encoder.transform(data) - return transformed[f"{name}.component"].to_numpy().astype(int) + continuous_encoder = "bayesian_gmm" + cont_encoder_params = dict(n_components=2) + categorical_encoder = "onehot" + cat_encoder_params = dict(handle_unknown="ignore", sparse=False) - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def transform(self, raw_data: pd.DataFrame) -> pd.DataFrame: - """Take raw data and output a matrix data.""" - output = raw_data.copy() - - for column_transform_info in self._column_transform_info: - name = column_transform_info.name - output[name] = self._transform_continuous( - column_transform_info, raw_data[name] + # TODO: check if this is correct + def _transform_feature( + self, column_transform_info: FeatureInfo, feature: pd.Series + ) -> pd.DataFrame: + if column_transform_info.feature_type == "discrete": + return super()._transform_feature(column_transform_info, feature) + bgm = column_transform_info.transform + out = bgm.transform(feature) + if out.shape != (len(feature), 3): + raise ValueError( + "BinEncoder should transform continuous features using a " + "BayesianGMM with 2 components" ) + # encoded as a binary vector corresponding to the first component + return pd.DataFrame(out.values[:, [1]], columns=[bgm.feature_name_in]) - return output + def _inverse_transform_feature( + self, column_transform_info: FeatureInfo, column_data: pd.DataFrame + ) -> pd.Series: + if column_transform_info == "discrete": + return super()._inverse_transform_feature( + column_transform_info, column_data + ) + bgm = column_transform_info.transform + components = column_data.values.reshape(-1) + features = bgm.means[components] + return pd.Series(features, name=bgm.feature_name_in) class TimeSeriesTabularEncoder(TransformerMixin, BaseEstimator): diff --git 
a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb index 97e38401..b520308e 100644 --- a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -38,7 +38,6 @@ "# stdlib\n", "import sys\n", "import warnings\n", - "sys.path.insert(0, '../src')\n", "\n", "# third party\n", "import numpy as np\n", From 2750791100fb9d4e30d46621d8f3b2fbc5e8c1a7 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 30 Mar 2023 16:30:54 +0200 Subject: [PATCH 30/95] add LogDistribution and LogIntDistribution --- src/synthcity/plugins/core/distribution.py | 41 ++++++++++++++++--- .../plugins/core/models/data_encoder.py | 4 +- src/synthcity/plugins/generic/plugin_ddpm.py | 20 +++++---- 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index fb486e0a..3aff6f74 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -260,7 +260,7 @@ def max(self) -> Any: return self.high def __eq__(self, other: Any) -> bool: - if not isinstance(other, FloatDistribution): + if not isinstance(other, type(self)): return False return ( @@ -273,6 +273,21 @@ def dtype(self) -> str: return "float" +class LogDistribution(FloatDistribution): + low: float = np.iinfo(np.int64).min + high: float = np.iinfo(np.int64).max + base: float = 10.0 + _log_low: float = np.log(low) / np.log(base) + _log_high: float = np.log(high) / np.log(base) + + def sample(self, count: int = 1) -> Any: + np.random.seed(self.random_state) + msamples = self.sample_marginal(count) + if msamples is not None: + return msamples + return self.base ** np.random.uniform(self._log_low, self._log_high, count) + + class IntegerDistribution(Distribution): """ .. 
inheritance-diagram:: synthcity.plugins.core.distribution.IntegerDistribution @@ -345,7 +360,20 @@ def dtype(self) -> str: return "int" -OFFSET = 120 +class LogIntDistribution(FloatDistribution): + low: int = np.iinfo(np.int64).min + high: int = np.iinfo(np.int64).max + base: float = 10.0 + _log_low: float = np.log(low) / np.log(base) + _log_high: float = np.log(high) / np.log(base) + + def sample(self, count: int = 1) -> Any: + np.random.seed(self.random_state) + msamples = self.sample_marginal(count) + if msamples is not None: + return msamples + s = self.base ** np.random.uniform(self._log_low, self._log_high, count) + return s.astype(int) class DatetimeDistribution(Distribution): @@ -356,6 +384,7 @@ class DatetimeDistribution(Distribution): low: datetime = datetime.utcfromtimestamp(0) high: datetime = datetime.now() + offset: int = 120 @validator("low", always=True) def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: @@ -363,7 +392,7 @@ def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: if mkey in values and values[mkey] is not None: v = values[mkey].index.min() - return v - timedelta(seconds=OFFSET) + return v - timedelta(seconds=cls.offset) @validator("high", always=True) def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime: @@ -371,7 +400,7 @@ def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime: if mkey in values and values[mkey] is not None: v = values[mkey].index.max() - return v + timedelta(seconds=OFFSET) + return v + timedelta(seconds=cls.offset) def get(self) -> List[Any]: return [self.name, self.low, self.high] @@ -397,8 +426,8 @@ def has(self, val: datetime) -> bool: def includes(self, other: "Distribution") -> bool: return self.min() - timedelta( - seconds=OFFSET - ) <= other.min() and other.max() <= self.max() + timedelta(seconds=OFFSET) + seconds=self.offset + ) <= other.min() and other.max() <= self.max() + timedelta(seconds=self.offset) def 
as_constraint(self) -> Constraints: return Constraints( diff --git a/src/synthcity/plugins/core/models/data_encoder.py b/src/synthcity/plugins/core/models/data_encoder.py index 57fdbc1c..90ed7e1e 100644 --- a/src/synthcity/plugins/core/models/data_encoder.py +++ b/src/synthcity/plugins/core/models/data_encoder.py @@ -14,7 +14,7 @@ StandardScaler, ) -FeatureEncoder = Any +FeatureEncoder = Any # tried to use ForwardRef but it didn't work under mypy class FeatureEncoder(TransformerMixin, BaseEstimator): # type: ignore @@ -224,6 +224,7 @@ def fit(self, x: pd.Series, y: Any = None) -> "GaussianQuantileTransformer": "minmax": MinMaxScaler, "quantile": GaussianQuantileTransformer, "bayesian_gmm": BayesianGMMEncoder, + "passthrough": FeatureEncoder, } @@ -240,6 +241,7 @@ def get_encoder(encoder: Union[str, type]) -> Type[FeatureEncoder]: - minmax - quantile - bayesian_gmm + - Passthrough """ if isinstance(encoder, type): # custom encoder return FeatureEncoder.wraps(encoder) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 855eb4ec..b7f24974 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -15,7 +15,12 @@ # synthcity absolute from synthcity.plugins.core.dataloader import DataLoader -from synthcity.plugins.core.distribution import CategoricalDistribution, Distribution +from synthcity.plugins.core.distribution import ( + Distribution, + IntegerDistribution, + LogDistribution, + LogIntDistribution, +) from synthcity.plugins.core.models.tabular_ddpm import TabDDPM from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema @@ -174,13 +179,12 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: Gaussian diffusion loss MSE """ return [ - # TODO: change to loguniform distribution - CategoricalDistribution(name="lr", choices=[1e-5, 1e-4, 1e-3, 2e-3, 3e-3]), - CategoricalDistribution(name="batch_size", 
choices=[256, 4096]), - CategoricalDistribution(name="num_timesteps", choices=[100, 1000]), - CategoricalDistribution(name="n_iter", choices=[5000, 10000, 20000]), - CategoricalDistribution(name="n_layers_hidden", choices=[2, 4, 6, 8]), - CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]), + LogDistribution(name="lr", low=1e-5, high=1e-1), + LogIntDistribution(name="batch_size", low=256, high=4096), + IntegerDistribution(name="num_timesteps", choices=[100, 1000]), + LogIntDistribution(name="n_iter", low=1000, high=10000), + IntegerDistribution(name="n_layers_hidden", low=2, high=8), + LogIntDistribution(name="dim_hidden", low=128, high=1024), ] def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": From 52011d30127cc906bc2705991544102db12a5579 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 30 Mar 2023 16:37:46 +0200 Subject: [PATCH 31/95] update DDPM to use TabularEncoder --- src/synthcity/plugins/generic/plugin_ddpm.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index b7f24974..1588253f 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -22,6 +22,7 @@ LogIntDistribution, ) from synthcity.plugins.core.models.tabular_ddpm import TabDDPM +from synthcity.plugins.core.models.tabular_encoder import TabularEncoder from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema from synthcity.utils.callbacks import Callback @@ -200,11 +201,6 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": cond = kwargs.pop("cond", None) self.loss_history = None - # note that the TabularEncoder is not used in this plugin, because the - # Gaussian multinomial diffusion module needs to know the number of classes - # for each discrete feature before it applies 
torch.nn.functional.one_hot - # on these features, and it also preprocesses the continuous features differently. - if args: raise ValueError("Only keyword arguments are allowed") @@ -219,6 +215,11 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": self.target_name = cond.name self.target_iloc = list(X.columns).index(cond.name) + self.encoder = TabularEncoder( + categorical_encoder="passthrough", continuous_encoder="quantile" + ) + df = self.encoder.fit_transform(df) + if cond is not None: if type(cond) is str: cond = df[cond] @@ -245,6 +246,7 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader def callback(count): # type: ignore data = self.model.generate(count, cond=cond) + data = self.encoder.inverse_transform(data) if self.is_classification: data = np.insert(data, self.target_iloc, cond, axis=1) return data From 0ee6c8b58336dae255be2ce6994fcb4b0634de37 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 31 Mar 2023 00:55:57 +0200 Subject: [PATCH 32/95] update test_tabular_encoder and debug --- .gitignore | 1 - .../plugins/core/models/data_encoder.py | 186 ++++++++++++------ .../core/models/tabular_ddpm/__init__.py | 14 +- .../plugins/core/models/tabular_encoder.py | 54 +++-- src/synthcity/plugins/generic/plugin_ddpm.py | 20 +- .../core/models/test_tabular_encoder.py | 73 +++---- ...al8_tabular_modelling_with_diffusion.ipynb | 62 +++--- 7 files changed, 225 insertions(+), 185 deletions(-) diff --git a/.gitignore b/.gitignore index b2bc0daa..41f36b84 100644 --- a/.gitignore +++ b/.gitignore @@ -67,4 +67,3 @@ lightning_logs generated MNIST cifar-10* -src/test.py diff --git a/src/synthcity/plugins/core/models/data_encoder.py b/src/synthcity/plugins/core/models/data_encoder.py index 90ed7e1e..97cdcebb 100644 --- a/src/synthcity/plugins/core/models/data_encoder.py +++ b/src/synthcity/plugins/core/models/data_encoder.py @@ -1,5 +1,5 @@ # stdlib -from typing import Any, List, Type, Union 
+from typing import Any, List, Optional, Type, Union # third party import numpy as np @@ -8,51 +8,84 @@ from sklearn.base import BaseEstimator, TransformerMixin from sklearn.mixture import BayesianGaussianMixture from sklearn.preprocessing import ( + LabelEncoder, MinMaxScaler, OneHotEncoder, QuantileTransformer, + RobustScaler, StandardScaler, ) + +def validate_shape(x: np.ndarray, n_dim: int) -> np.ndarray: + if n_dim == 1: + if x.ndim == 2: + x = np.squeeze(x, axis=1) + if x.ndim != 1: + raise ValueError("array must be 1D") + return x + elif n_dim == 2: + if x.ndim == 1: + x = x.reshape(-1, 1) + if x.ndim != 2: + raise ValueError("array must be 2D") + return x + else: + raise ValueError("n_dim must be 1 or 2") + + FeatureEncoder = Any # tried to use ForwardRef but it didn't work under mypy class FeatureEncoder(TransformerMixin, BaseEstimator): # type: ignore - """Base feature encoder, with sklearn-style API""" + """ + Base feature encoder with sklearn-style API. + """ - def __new__(cls, **kwargs: Any) -> FeatureEncoder: - obj = super().__new__() - obj.__dict__.update(kwargs) # auto set all parameters as attributes - return obj + n_dim_in: int = 1 + n_dim_out: int = 2 + n_features_out: int + feature_name_in: str + feature_names_out: List[str] + feature_types_out: List[str] + categorical: bool = False # used by get_feature_types_out + + def __init__( + self, n_dim_in: Optional[int] = None, n_dim_out: Optional[int] = None + ) -> None: + super().__init__() + if n_dim_in is not None: + self.n_dim_in = n_dim_in + if n_dim_out is not None: + self.n_dim_out = n_dim_out @validate_arguments(config=dict(arbitrary_types_allowed=True)) def fit(self, x: pd.Series, y: Any = None, **kwargs: Any) -> FeatureEncoder: self.feature_name_in = x.name - out = self._fit(x, **kwargs)._transform(x) - - if np.ndim(out) == 1: - self.n_features_out = 1 - else: - self.n_features_out = np.shape(out)[1] - + self.feature_type_in = self._get_feature_type(x) + input = validate_shape(x.values, 
self.n_dim_in) + output = self._fit(input, **kwargs)._transform(input) + self._out_shape = (-1, *output.shape[1:]) # for inverse_transform + output = validate_shape(output, self.n_dim_out) self.feature_names_out = self.get_feature_names_out() - + self.n_features_out = len(self.feature_names_out) + self.feature_types_out = self.get_feature_types_out(output) return self - def _fit(self, x: pd.Series, **kwargs: Any) -> FeatureEncoder: + def _fit(self, x: np.ndarray, **kwargs: Any) -> FeatureEncoder: return self @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def transform(self, x: pd.Series) -> Any: - out = self._transform(x) - if isinstance(out, np.ndarray): - if out.ndim == 1: - return pd.Series(out, self.feature_name_in) - else: - return pd.DataFrame(out, columns=self.feature_names_out) - return out - - def _transform(self, x: pd.Series) -> Any: + def transform(self, x: pd.Series) -> Union[pd.DataFrame, pd.Series]: + data = validate_shape(x.values, self.n_dim_in) + out = self._transform(data) + out = validate_shape(out, self.n_dim_out) + if self.n_dim_out == 1: + return pd.Series(out, name=self.feature_name_in) + else: + return pd.DataFrame(out, columns=self.feature_names_out) + + def _transform(self, x: np.ndarray) -> np.ndarray: return x def get_feature_names_out(self) -> List[str]: @@ -60,61 +93,79 @@ def get_feature_names_out(self) -> List[str]: if n == 1: return [self.feature_name_in] else: - return [self.feature_name_in + str(i) for i in range(n)] + return [f"{self.feature_name_in}_{i}" for i in range(n)] + + def get_feature_types_out(self, output: np.ndarray) -> List[str]: + t = self._get_feature_type(output) + return [t] * self.n_features_out + + def _get_feature_type(self, x: Any) -> str: + if self.categorical: + return "discrete" + elif np.issubdtype(x.dtype, np.floating): + return "continuous" + else: + return "discrete" @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def inverse_transform(self, data: Any) -> pd.Series: 
- x = self._inverse_transform(data) + def inverse_transform(self, df: Union[pd.DataFrame, pd.Series]) -> pd.Series: + y = df.values.reshape(self._out_shape) + x = self._inverse_transform(y) + x = validate_shape(x, 1) return pd.Series(x, name=self.feature_name_in) - def _inverse_transform(self, data: Any) -> pd.Series: + def _inverse_transform(self, data: np.ndarray) -> np.ndarray: return data @classmethod - def wraps(cls, encoder_class: TransformerMixin) -> Type[FeatureEncoder]: + def wraps( + cls, encoder_class: TransformerMixin, **params: Any + ) -> Type[FeatureEncoder]: """Wraps sklearn transformer to FeatureEncoder.""" class WrappedEncoder(FeatureEncoder): + n_dim_in = 2 # most sklearn transformers accept 2D input + def __init__(self, *args: Any, **kwargs: Any) -> None: self.encoder = encoder_class(*args, **kwargs) - def _fit(self, x: pd.Series, **kwargs: Any) -> FeatureEncoder: + def _fit(self, x: np.ndarray, **kwargs: Any) -> FeatureEncoder: self.encoder.fit(x, **kwargs) return self - def _transform(self, x: pd.Series) -> Any: + def _transform(self, x: np.ndarray) -> np.ndarray: return self.encoder.transform(x) - def _inverse_transform(self, x: pd.Series) -> Any: - return self.encoder.inverse_transform(x) + def _inverse_transform(self, data: np.ndarray) -> np.ndarray: + return self.encoder.inverse_transform(data) def get_feature_names_out(self) -> List[str]: - return self.encoder.get_feature_names_out([self.feature_name_in]) - - for attr in ( - "__module__", - "__name__", - "__qualname__", - "__doc__", - "__annotations__", - ): + return list(self.encoder.get_feature_names_out([self.feature_name_in])) + + for attr in ("__name__", "__qualname__", "__doc__"): setattr(WrappedEncoder, attr, getattr(encoder_class, attr)) + for attr, val in params.items(): + setattr(WrappedEncoder, attr, val) return WrappedEncoder -OneHotEncoder = FeatureEncoder.wraps(OneHotEncoder) +OneHotEncoder = FeatureEncoder.wraps(OneHotEncoder, categorical=True) +LabelEncoder = 
FeatureEncoder.wraps(LabelEncoder, n_dim_out=1, categorical=True) StandardScaler = FeatureEncoder.wraps(StandardScaler) MinMaxScaler = FeatureEncoder.wraps(MinMaxScaler) +RobustScaler = FeatureEncoder.wraps(RobustScaler) class DatetimeEncoder(FeatureEncoder): """Datetime variables encoder""" - def _transform(self, x: pd.Series) -> pd.Series: + n_dim_out = 1 + + def _transform(self, x: np.ndarray) -> np.ndarray: return pd.to_numeric(x).astype(float) - def _inverse_transform(self, data: pd.Series) -> pd.Series: + def _inverse_transform(self, data: np.ndarray) -> np.ndarray: return pd.to_datetime(data) @@ -129,6 +180,11 @@ def __init__( clip_output: bool = True, std_multiplier: int = 4, ) -> None: + self.n_components = n_components + self.random_state = random_state + self.weight_threshold = weight_threshold + self.clip_output = clip_output + self.std_multiplier = std_multiplier self.model = BayesianGaussianMixture( n_components=n_components, random_state=random_state, @@ -136,19 +192,19 @@ def __init__( ) self.weights: List[float] - def _fit(self, x: pd.Series, **kwargs: Any) -> "BayesianGaussianMixture": + def _fit(self, x: np.ndarray, **kwargs: Any) -> "BayesianGaussianMixture": self.min_value = x.min() self.max_value = x.max() - self.model.fit(x.values.reshape(-1, 1)) + self.model.fit(x.reshape(-1, 1)) self.weights = self.model.weights_ self.means = self.model.means_.reshape(-1) self.stds = np.sqrt(self.model.covariances_).reshape(-1) return self - def _transform(self, x: pd.Series) -> pd.DataFrame: - x = x.values.reshape(-1, 1) + def _transform(self, x: np.ndarray) -> np.ndarray: + x = x.reshape(-1, 1) means = self.means.reshape(1, -1) stds = self.stds.reshape(1, -1) @@ -161,30 +217,32 @@ def _transform(self, x: pd.Series) -> pd.DataFrame: components = np.argmax(component_probs, axis=1) normalized = normalized_values[np.arange(len(x)), components] - if self.clip_output: + if self.clip_output: # why use 0.99 instead of 1? 
normalized = np.clip(normalized, -0.99, 0.99) normalized = normalized.reshape(-1, 1) - components = np.eye(self.n_components)[components] # onehot + components = np.eye(self.n_components, dtype=int)[components] return np.hstack([normalized, components]) def get_feature_names_out(self) -> List[str]: name = self.feature_name_in return [f"{name}.value"] + [ - f"{name}.component_{i}" for i in range(self.n_features_out - 1) + f"{name}.component_{i}" for i in range(self.n_components) ] - def _inverse_transform(self, data: pd.DataFrame) -> pd.Series: - if self.clip_output: - data = np.clip(data.values[:, 0], -1, 1) + def get_feature_types_out(self, output: np.ndarray) -> List[str]: + return ["continuous"] + ["discrete"] * self.n_components + + def _inverse_transform(self, data: np.ndarray) -> np.ndarray: + components = np.argmax(data[:, 1:], axis=1) - means = self.model.means_.reshape([-1]) - stds = np.sqrt(self.model.covariances_).reshape([-1]) - components = np.argmax(data.values[:, 1:], axis=1) + data = data[:, 0] + if self.clip_output: + data = np.clip(data, -1.0, 1.0) # recreate data - std_t = stds[components] - mean_t = means[components] + mean_t = self.means[components] + std_t = self.stds[components] reversed_data = data * self.std_multiplier * std_t + mean_t # clip values @@ -212,7 +270,7 @@ def __init__( copy=copy, ) - def fit(self, x: pd.Series, y: Any = None) -> "GaussianQuantileTransformer": + def fit(self, x: np.ndarray, y: Any = None) -> "GaussianQuantileTransformer": self.n_quantiles = max(min(len(x) // 30, 1000), 10) return super().fit(x, y) @@ -220,8 +278,10 @@ def fit(self, x: pd.Series, y: Any = None) -> "GaussianQuantileTransformer": ENCODERS = { "datetime": DatetimeEncoder, "onehot": OneHotEncoder, + "label": LabelEncoder, "standard": StandardScaler, "minmax": MinMaxScaler, + "robust": RobustScaler, "quantile": GaussianQuantileTransformer, "bayesian_gmm": BayesianGMMEncoder, "passthrough": FeatureEncoder, @@ -236,9 +296,11 @@ def get_encoder(encoder: 
Union[str, type]) -> Type[FeatureEncoder]: - datetime - Categorical - onehot + - label - Continuous - standard - minmax + - robust - quantile - bayesian_gmm - Passthrough diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index d6141f81..0332fa03 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -83,19 +83,17 @@ def fit( else: self.n_classes = 0 + self.feature_names = X.columns cat_cols = discrete_columns(X, return_counts=True) if cat_cols: - ini_cols = X.columns cat_cols, cat_counts = zip(*cat_cols) # reorder the columns so that the categorical ones go to the end X = X[np.hstack([X.columns[~X.keys().isin(cat_cols)], cat_cols])] - cur_cols = X.columns - # find the permutation from the reordered columns to the original ones - self._col_perm = np.argsort(cur_cols)[np.argsort(np.argsort(ini_cols))] + self.feature_names_out = X.columns else: cat_counts = [0] - self._col_perm = np.arange(X.shape[1]) + self.feature_names_out = self.feature_names model_params = dict( num_classes=self.n_classes, @@ -207,10 +205,10 @@ def fit( return self - def generate(self, count: int, cond: Any = None) -> np.ndarray: + def generate(self, count: int, cond: Any = None) -> pd.DataFrame: self.diffusion.eval() if cond is not None: cond = torch.tensor(cond, dtype=torch.long, device=self.device) sample = self.diffusion.sample_all(count, cond).detach().cpu().numpy() - sample = sample[:, self._col_perm] - return sample + df = pd.DataFrame(sample, columns=self.feature_names_out) + return df[self.feature_names] diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 0eb2a096..e5f260e3 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -9,6 +9,7 @@ import pandas as pd from pydantic 
import BaseModel, validate_arguments, validator from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.preprocessing import MinMaxScaler # synthcity absolute import synthcity.logger as log @@ -16,15 +17,16 @@ from synthcity.utils.serialization import dataframe_hash # synthcity relative -from .data_encoder import FeatureEncoder, get_encoder +from .data_encoder import get_encoder class FeatureInfo(BaseModel): name: str feature_type: str - transform: FeatureEncoder + transform: Any output_dimensions: int transformed_features: List[str] + trans_feature_types: List[str] @validator("feature_type") def _feature_type_validator(cls: Any, v: str) -> str: @@ -66,6 +68,7 @@ def __init__( self, *, whitelist: tuple = (), + max_clusters: int = 10, categorical_limit: int = 10, categorical_encoder: Optional[Union[str, type]] = None, continuous_encoder: Optional[Union[str, type]] = None, @@ -78,8 +81,9 @@ def __init__( whitelist (tuple): Columns that will not be transformed. """ - self.categorical_limit = categorical_limit self.whitelist = whitelist + self.categorical_limit = categorical_limit + self.max_clusters = max_clusters # for compatibility if categorical_encoder is not None: self.categorical_encoder = categorical_encoder if continuous_encoder is not None: @@ -92,6 +96,8 @@ def __init__( self.cont_encoder_params = cont_encoder_params else: self.cont_encoder_params = self.cont_encoder_params.copy() + if self.continuous_encoder == "bayesian_gmm": + self.cont_encoder_params["n_components"] = max_clusters @validate_arguments(config=dict(arbitrary_types_allowed=True)) def _fit_feature(self, feature: pd.Series, feature_type: str) -> FeatureInfo: @@ -120,6 +126,7 @@ def _fit_feature(self, feature: pd.Series, feature_type: str) -> FeatureInfo: transform=encoder, output_dimensions=encoder.n_features_out, transformed_features=encoder.feature_names_out, + trans_feature_types=encoder.feature_types_out, ) @validate_arguments(config=dict(arbitrary_types_allowed=True)) @@ 
-136,7 +143,7 @@ def fit( self.output_dimensions = 0 self._column_raw_dtypes = raw_data.infer_objects().dtypes - self._column_transform_info_list = [] + self._column_transform_info_list: Sequence[FeatureInfo] = [] for name in raw_data.columns: if name in self.whitelist: @@ -233,7 +240,7 @@ def inverse_transform(self, data: pd.DataFrame) -> pd.DataFrame: ).astype(self._column_raw_dtypes.filter(names)) return recovered_data - def layout(self) -> List[Tuple]: + def layout(self) -> Sequence[FeatureInfo]: """Get the layout of the encoded dataset. Returns a list of tuple, describing each column as: @@ -258,7 +265,7 @@ def get_column_info(self, name: str) -> FeatureInfo: @validate_arguments(config=dict(arbitrary_types_allowed=True)) def activation_layout( self, discrete_activation: str, continuous_activation: str - ) -> Sequence[Tuple]: + ) -> Sequence[Tuple[str, int]]: """Get the layout of the activations. Returns a list of tuple, describing each column as: @@ -267,21 +274,9 @@ def activation_layout( """ out = [] for column_transform_info in self._column_transform_info_list: - if column_transform_info.feature_type == "continuous": - out.extend( - [ - (continuous_activation, 1), - ( - discrete_activation, - column_transform_info.output_dimensions - 1, - ), - ] - ) - else: - out.append( - (discrete_activation, column_transform_info.output_dimensions) - ) - + for t in column_transform_info.trans_feature_types: + act = discrete_activation if t == "discrete" else continuous_activation + out.append((act, 1)) return out @@ -305,13 +300,9 @@ def _transform_feature( return super()._transform_feature(column_transform_info, feature) bgm = column_transform_info.transform out = bgm.transform(feature) - if out.shape != (len(feature), 3): - raise ValueError( - "BinEncoder should transform continuous features using a " - "BayesianGMM with 2 components" - ) - # encoded as a binary vector corresponding to the first component - return pd.DataFrame(out.values[:, [1]], 
columns=[bgm.feature_name_in]) + return pd.DataFrame( + out.values[:, 1:].argmax(axis=1), columns=[bgm.feature_name_in] + ) def _inverse_transform_feature( self, column_transform_info: FeatureInfo, column_data: pd.DataFrame @@ -339,12 +330,10 @@ def __init__( max_clusters: int = 10, categorical_limit: int = 10, whitelist: list = [], - encoder: str = "minmax", ) -> None: self.max_clusters = max_clusters self.categorical_limit = categorical_limit self.whitelist = whitelist - self.encoder = encoder def fit_temporal( self, @@ -369,8 +358,9 @@ def fit_temporal( self.temporal_encoder.fit(temporal_df) # Temporal horizons - self.observation_times_encoder = get_encoder(self.encoder) - self.observation_times_encoder.fit(np.asarray(observation_times).reshape(-1, 1)) + self.observation_times_encoder = MinMaxScaler().fit( + np.asarray(observation_times).reshape(-1, 1) + ) return self diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 1588253f..b1150ac9 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -154,6 +154,12 @@ def __init__( dim_embed=dim_embed, ) + self.encoder = TabularEncoder( + categorical_encoder="passthrough", + continuous_encoder="quantile", + cont_encoder_params=dict(random_state=random_state), + ) + @staticmethod def name() -> str: return "ddpm" @@ -198,6 +204,8 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": If the task is regression, the target variable is not specially treated. There is no condition by default, but can be given by the user, either as a column name or an array-like. 
""" df = X.dataframe() + self.feature_names = df.columns + cond = kwargs.pop("cond", None) self.loss_history = None @@ -213,11 +221,7 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": self._labels, self._cond_dist = np.unique(cond, return_counts=True) self._cond_dist = self._cond_dist / self._cond_dist.sum() self.target_name = cond.name - self.target_iloc = list(X.columns).index(cond.name) - self.encoder = TabularEncoder( - categorical_encoder="passthrough", continuous_encoder="quantile" - ) df = self.encoder.fit_transform(df) if cond is not None: @@ -245,11 +249,11 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader raise ValueError("The length of cond is less than the required count") def callback(count): # type: ignore - data = self.model.generate(count, cond=cond) - data = self.encoder.inverse_transform(data) + df = self.model.generate(count, cond=cond) + df = self.encoder.inverse_transform(df) if self.is_classification: - data = np.insert(data, self.target_iloc, cond, axis=1) - return data + df = df.join(pd.Series(cond, name=self.target_name)) + return df[self.feature_names] # reorder columns return self._safe_generate(callback, count, syn_schema, **kwargs) diff --git a/tests/plugins/core/models/test_tabular_encoder.py b/tests/plugins/core/models/test_tabular_encoder.py index 948399b5..6837190a 100644 --- a/tests/plugins/core/models/test_tabular_encoder.py +++ b/tests/plugins/core/models/test_tabular_encoder.py @@ -78,7 +78,6 @@ def test_encoder_fit_transform(max_clusters: int) -> None: assert set(encoded[f"{column.name}_{val}"].unique()).issubset( set([0, 1]) ) - else: assert f"{column.name}.value" in encoded.columns for enc_col in encoded.columns: @@ -102,6 +101,27 @@ def test_encoder_inverse_transform(max_clusters: int) -> None: assert np.abs(X - recovered).sum().sum() < 5 +def check_equal_layouts( + layout: list, act_layout: list, disc_act: str, cont_act: str +) -> None: + expected_act_layout = 
[] + for col_info in layout: + if col_info.feature_type == "continuous": + expected_act_layout.append(cont_act) + for _ in range(col_info.output_dimensions - 1): + expected_act_layout.append(disc_act) + else: + for _ in range(col_info.output_dimensions): + expected_act_layout.append(disc_act) + + expanded_act_layout = [] + for act, num in act_layout: + for _ in range(num): + expanded_act_layout.append(act) + + assert expanded_act_layout == expected_act_layout + + def test_encoder_activation_layout() -> None: X, _ = load_diabetes(return_X_y=True, as_frame=True) net = TabularEncoder() @@ -113,20 +133,7 @@ def test_encoder_activation_layout() -> None: layout = net.layout() assert len(layout) <= len(act_layout) - - act_step = 0 - - for col_info in layout: - if col_info.feature_type == "continuous": - assert act_layout[act_step] == ("tanh", 1) - assert act_layout[act_step + 1] == ( - "softmax", - col_info.output_dimensions - 1, - ) - act_step += 2 - else: - assert act_layout[act_step] == ("softmax", col_info.output_dimensions) - act_step += 1 + check_equal_layouts(layout, act_layout, "softmax", "tanh") def test_bin_encoder() -> None: @@ -138,6 +145,8 @@ def test_bin_encoder() -> None: binned = net.transform(X) for col in X.columns: + # ! the target column is transformed by OneHotEncoder to target_0, target_1, target_2 + # ! 
will result in a KeyError assert len(binned[col].unique()) <= 10 @@ -272,35 +281,5 @@ def test_ts_encoder_activation_layout(source: Any) -> None: assert len(static_layout) <= len(static_act_layout) assert len(temporal_layout) <= len(temporal_act_layout) - - act_step = 0 - for col_info in static_layout: - if col_info.feature_type == "continuous": - assert static_act_layout[act_step] == ("tanh", 1) - assert static_act_layout[act_step + 1] == ( - "softmax", - col_info.output_dimensions - 1, - ) - act_step += 2 - else: - assert static_act_layout[act_step] == ( - "softmax", - col_info.output_dimensions, - ) - act_step += 1 - - act_step = 0 - for col_info in temporal_layout: - if col_info.feature_type == "continuous": - assert temporal_act_layout[act_step] == ("tanh", 1) - assert temporal_act_layout[act_step + 1] == ( - "softmax", - col_info.output_dimensions - 1, - ) - act_step += 2 - else: - assert temporal_act_layout[act_step] == ( - "softmax", - col_info.output_dimensions, - ) - act_step += 1 + check_equal_layouts(static_layout, static_act_layout, "softmax", "tanh") + check_equal_layouts(temporal_layout, temporal_act_layout, "softmax", "tanh") diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb index b520308e..197b5737 100644 --- a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -65,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "id": "51076cdc", "metadata": {}, "outputs": [ @@ -158,7 +158,7 @@ "4 0 " ] }, - "execution_count": 12, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -176,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "id": "52397e4a", "metadata": {}, "outputs": [ @@ -189,7 +189,7 @@ "Name: target, dtype: int64" ] }, - "execution_count": 13, + "execution_count": 3, "metadata": {}, 
"output_type": "execute_result" } @@ -209,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 4, "id": "3bf24be4", "metadata": {}, "outputs": [ @@ -217,27 +217,35 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-27T15:19:24.516935+0200][30696][INFO] Step 100: MLoss: 0.0 GLoss: 0.2235 Sum: 0.2235\n", - "[2023-03-27T15:19:25.913968+0200][30696][INFO] Step 200: MLoss: 0.0 GLoss: 0.2298 Sum: 0.2298\n", - "[2023-03-27T15:19:27.191123+0200][30696][INFO] Step 300: MLoss: 0.0 GLoss: 0.2305 Sum: 0.2305\n", - "[2023-03-27T15:19:28.432055+0200][30696][INFO] Step 400: MLoss: 0.0 GLoss: 0.2273 Sum: 0.2273\n", - "[2023-03-27T15:19:29.766838+0200][30696][INFO] Step 500: MLoss: 0.0 GLoss: 0.2333 Sum: 0.2333\n", - "[2023-03-27T15:19:31.280538+0200][30696][INFO] Step 600: MLoss: 0.0 GLoss: 0.221 Sum: 0.221\n", - "[2023-03-27T15:19:33.034999+0200][30696][INFO] Step 700: MLoss: 0.0 GLoss: 0.2123 Sum: 0.2123\n", - "[2023-03-27T15:19:34.519078+0200][30696][INFO] Step 800: MLoss: 0.0 GLoss: 0.2212 Sum: 0.2212\n", - "[2023-03-27T15:19:36.020932+0200][30696][INFO] Step 900: MLoss: 0.0 GLoss: 0.2014 Sum: 0.2014\n", - "[2023-03-27T15:19:38.330664+0200][30696][INFO] Step 1000: MLoss: 0.0 GLoss: 0.2069 Sum: 0.2069\n" + "[2023-03-31T00:29:29.172830+0200][10148][INFO] Encoding sepal length (cm) 8461685668942494555\n" ] }, { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "ename": "TypeError", + "evalue": "__init__() got an unexpected keyword argument 'n_components'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[0mplugin\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mPlugins\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"ddpm\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mplugin_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[0mplugin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[0;32m 242\u001b[0m )\n\u001b[0;32m 243\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 244\u001b[1;33m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 245\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfitted\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 246\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36m_fit\u001b[1;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[0;32m 221\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 223\u001b[1;33m \u001b[0mdf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit_transform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 224\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 225\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\sklearn\\utils\\_set_output.py\u001b[0m in \u001b[0;36mwrapped\u001b[1;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[0;32m 140\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 141\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 
142\u001b[1;33m \u001b[0mdata_to_wrap\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 143\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_to_wrap\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 144\u001b[0m \u001b[1;31m# only wrap the first output for cross decomposition\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\sklearn\\base.py\u001b[0m in \u001b[0;36mfit_transform\u001b[1;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[0;32m 857\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0my\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 858\u001b[0m \u001b[1;31m# fit method of arity 1 (unsupervised transformation)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 859\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 860\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 861\u001b[0m \u001b[1;31m# fit method 
of arity 2 (supervised transformation)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_encoder.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, raw_data, discrete_columns)\u001b[0m\n\u001b[0;32m 155\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 156\u001b[0m \u001b[0mftype\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"continuous\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 157\u001b[1;33m \u001b[0mcolumn_transform_info\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fit_feature\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mraw_data\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mftype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 158\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 159\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_dimensions\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mcolumn_transform_info\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_dimensions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + 
"\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", + "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_encoder.py\u001b[0m in \u001b[0;36m_fit_feature\u001b[1;34m(self, feature, feature_type)\u001b[0m\n\u001b[0;32m 117\u001b[0m \u001b[0mencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_encoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcategorical_encoder\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat_encoder_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 118\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 119\u001b[1;33m \u001b[0mencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_encoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontinuous_encoder\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcont_encoder_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m 
\u001b[0mencoder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\data_encoder.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 123\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 124\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mencoder_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 125\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_fit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mFeatureEncoder\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'n_components'" + ] } ], "source": [ @@ 
-265,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "e1a270c9", "metadata": {}, "outputs": [ @@ -361,7 +369,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "49b18ada", "metadata": {}, "outputs": [ @@ -406,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "a2e81779", "metadata": {}, "outputs": [ @@ -569,7 +577,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "1f55ffdb", "metadata": {}, "outputs": [ From 244854d7c0c6137f5de9f84436644875e8d65756 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 31 Mar 2023 01:11:44 +0200 Subject: [PATCH 33/95] debug and DDPM tutorial OK --- .../plugins/core/models/data_encoder.py | 2 +- src/synthcity/plugins/generic/plugin_ddpm.py | 3 +- ...al8_tabular_modelling_with_diffusion.ipynb | 1345 +++++++---------- 3 files changed, 508 insertions(+), 842 deletions(-) diff --git a/src/synthcity/plugins/core/models/data_encoder.py b/src/synthcity/plugins/core/models/data_encoder.py index 97cdcebb..518400fa 100644 --- a/src/synthcity/plugins/core/models/data_encoder.py +++ b/src/synthcity/plugins/core/models/data_encoder.py @@ -67,8 +67,8 @@ def fit(self, x: pd.Series, y: Any = None, **kwargs: Any) -> FeatureEncoder: output = self._fit(input, **kwargs)._transform(input) self._out_shape = (-1, *output.shape[1:]) # for inverse_transform output = validate_shape(output, self.n_dim_out) + self.n_features_out = output.shape[1] self.feature_names_out = self.get_feature_names_out() - self.n_features_out = len(self.feature_names_out) self.feature_types_out = self.get_feature_types_out(output) return self diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index b1150ac9..9ac18878 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -155,9 +155,10 @@ def 
__init__( ) self.encoder = TabularEncoder( - categorical_encoder="passthrough", continuous_encoder="quantile", + categorical_encoder="passthrough", cont_encoder_params=dict(random_state=random_state), + cat_encoder_params=dict(), ) @staticmethod diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb index 197b5737..d73d0f60 100644 --- a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -217,35 +217,31 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-31T00:29:29.172830+0200][10148][INFO] Encoding sepal length (cm) 8461685668942494555\n" + "[2023-03-31T01:04:28.062034+0200][12004][INFO] Encoding sepal length (cm) 8461685668942494555\n", + "[2023-03-31T01:04:28.068034+0200][12004][INFO] Encoding sepal width (cm) 7372477013158199918\n", + "[2023-03-31T01:04:28.074037+0200][12004][INFO] Encoding petal length (cm) 8795408021141068254\n", + "[2023-03-31T01:04:28.081036+0200][12004][INFO] Encoding petal width (cm) 1839870727438321343\n", + "[2023-03-31T01:04:29.905425+0200][12004][INFO] Step 100: MLoss: 0.0 GLoss: 0.3103 Sum: 0.3103\n", + "[2023-03-31T01:04:31.486761+0200][12004][INFO] Step 200: MLoss: 0.0 GLoss: 0.3111 Sum: 0.3111\n", + "[2023-03-31T01:04:33.076905+0200][12004][INFO] Step 300: MLoss: 0.0 GLoss: 0.317 Sum: 0.317\n", + "[2023-03-31T01:04:34.611746+0200][12004][INFO] Step 400: MLoss: 0.0 GLoss: 0.3009 Sum: 0.3009\n", + "[2023-03-31T01:04:36.176039+0200][12004][INFO] Step 500: MLoss: 0.0 GLoss: 0.3154 Sum: 0.3154\n", + "[2023-03-31T01:04:37.956754+0200][12004][INFO] Step 600: MLoss: 0.0 GLoss: 0.3055 Sum: 0.3055\n", + "[2023-03-31T01:04:39.561269+0200][12004][INFO] Step 700: MLoss: 0.0 GLoss: 0.2917 Sum: 0.2917\n", + "[2023-03-31T01:04:41.195544+0200][12004][INFO] Step 800: MLoss: 0.0 GLoss: 0.2817 Sum: 0.2817\n", + "[2023-03-31T01:04:42.967236+0200][12004][INFO] Step 900: MLoss: 0.0 
GLoss: 0.266 Sum: 0.266\n", + "[2023-03-31T01:04:44.913448+0200][12004][INFO] Step 1000: MLoss: 0.0 GLoss: 0.2793 Sum: 0.2793\n" ] }, { - "ename": "TypeError", - "evalue": "__init__() got an unexpected keyword argument 'n_components'", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[0mplugin\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPlugins\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"ddpm\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mplugin_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[0mplugin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[0;32m 242\u001b[0m )\n\u001b[0;32m 
243\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 244\u001b[1;33m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 245\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfitted\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 246\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36m_fit\u001b[1;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[0;32m 221\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 223\u001b[1;33m \u001b[0mdf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit_transform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 224\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 225\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\sklearn\\utils\\_set_output.py\u001b[0m in \u001b[0;36mwrapped\u001b[1;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[0;32m 140\u001b[0m 
\u001b[1;33m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 141\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 142\u001b[1;33m \u001b[0mdata_to_wrap\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 143\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_to_wrap\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 144\u001b[0m \u001b[1;31m# only wrap the first output for cross decomposition\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\sklearn\\base.py\u001b[0m in \u001b[0;36mfit_transform\u001b[1;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[0;32m 857\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0my\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 858\u001b[0m \u001b[1;31m# fit method of arity 1 (unsupervised transformation)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 859\u001b[1;33m 
\u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 860\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 861\u001b[0m \u001b[1;31m# fit method of arity 2 (supervised transformation)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_encoder.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, raw_data, discrete_columns)\u001b[0m\n\u001b[0;32m 155\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 156\u001b[0m \u001b[0mftype\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"continuous\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 157\u001b[1;33m \u001b[0mcolumn_transform_info\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fit_feature\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mraw_data\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mftype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 158\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 159\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_dimensions\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mcolumn_transform_info\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput_dimensions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_encoder.py\u001b[0m in \u001b[0;36m_fit_feature\u001b[1;34m(self, feature, feature_type)\u001b[0m\n\u001b[0;32m 117\u001b[0m \u001b[0mencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_encoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcategorical_encoder\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat_encoder_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 118\u001b[0m 
\u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 119\u001b[1;33m \u001b[0mencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_encoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontinuous_encoder\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcont_encoder_params\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[0mencoder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mD:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\data_encoder.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 123\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 124\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mencoder_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 125\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 
126\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_fit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mFeatureEncoder\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'n_components'" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -273,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "e1a270c9", "metadata": {}, "outputs": [ @@ -358,7 +354,7 @@ ")" ] }, - "execution_count": 15, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -369,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "49b18ada", "metadata": {}, "outputs": [ @@ -385,7 +381,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -414,7 +410,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "a2e81779", "metadata": {}, "outputs": [ @@ -449,82 +445,82 @@ " \n", " \n", " 0\n", - " 6.442119\n", - " 2.934733\n", - " 4.326933\n", - " 1.372570\n", + " 6.491386\n", + " 2.937301\n", + " 4.396537\n", + " 1.363964\n", " 1\n", " \n", " \n", " 1\n", - " 6.285412\n", - " 2.721440\n", - " 5.120901\n", - " 2.057547\n", + " 6.272807\n", + " 2.878930\n", + " 5.028617\n", + " 1.973149\n", " 2\n", " \n", " \n", " 2\n", - " 4.696350\n", - " 2.042726\n", - " 2.856909\n", - " 0.788935\n", + " 4.912787\n", + " 2.239502\n", + " 2.384605\n", + " 0.845205\n", " 1\n", " \n", " \n", " 3\n", - " 5.336019\n", - " 2.688533\n", - " 4.163283\n", - " 1.192051\n", + " 5.115768\n", + " 2.636920\n", + " 3.933653\n", + " 1.100583\n", " 1\n", " \n", " \n", " 4\n", - " 6.081825\n", - " 3.221682\n", - " 4.645768\n", - " 1.505293\n", + " 5.946947\n", + " 2.976103\n", + " 4.557983\n", + " 1.417799\n", " 1\n", " \n", " \n", " 5\n", - " 5.690165\n", - " 2.336088\n", - " 4.105630\n", - " 1.296607\n", + " 5.528565\n", + " 2.197114\n", + " 4.133016\n", + " 1.296019\n", " 1\n", " \n", " \n", " 6\n", - " 5.398935\n", - " 2.757713\n", - " 3.809984\n", - " 1.161369\n", + " 5.275113\n", + " 2.565652\n", + " 3.698843\n", + " 1.068934\n", " 1\n", " \n", " \n", " 7\n", - " 7.358270\n", - " 3.283428\n", - " 6.496590\n", - " 2.317238\n", + " 7.900000\n", + " 4.400000\n", + " 6.899995\n", + " 2.500000\n", " 2\n", " \n", " \n", " 8\n", - " 6.595327\n", - " 2.598526\n", - " 5.805653\n", - " 1.451353\n", + " 6.899334\n", + " 2.847685\n", + " 6.243627\n", + " 1.561012\n", " 2\n", " \n", " \n", " 9\n", - " 5.224718\n", - " 2.796224\n", - " 3.500915\n", - " 1.125248\n", + " 5.267148\n", + " 2.780006\n", + " 3.565531\n", + " 1.128439\n", " 1\n", " \n", " \n", @@ -533,16 +529,16 @@ ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", - "0 6.442119 2.934733 4.326933 
1.372570 \n", - "1 6.285412 2.721440 5.120901 2.057547 \n", - "2 4.696350 2.042726 2.856909 0.788935 \n", - "3 5.336019 2.688533 4.163283 1.192051 \n", - "4 6.081825 3.221682 4.645768 1.505293 \n", - "5 5.690165 2.336088 4.105630 1.296607 \n", - "6 5.398935 2.757713 3.809984 1.161369 \n", - "7 7.358270 3.283428 6.496590 2.317238 \n", - "8 6.595327 2.598526 5.805653 1.451353 \n", - "9 5.224718 2.796224 3.500915 1.125248 \n", + "0 6.491386 2.937301 4.396537 1.363964 \n", + "1 6.272807 2.878930 5.028617 1.973149 \n", + "2 4.912787 2.239502 2.384605 0.845205 \n", + "3 5.115768 2.636920 3.933653 1.100583 \n", + "4 5.946947 2.976103 4.557983 1.417799 \n", + "5 5.528565 2.197114 4.133016 1.296019 \n", + "6 5.275113 2.565652 3.698843 1.068934 \n", + "7 7.900000 4.400000 6.899995 2.500000 \n", + "8 6.899334 2.847685 6.243627 1.561012 \n", + "9 5.267148 2.780006 3.565531 1.128439 \n", "\n", " target \n", "0 1 \n", @@ -577,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "1f55ffdb", "metadata": {}, "outputs": [ @@ -612,74 +608,74 @@ " \n", " \n", " 0\n", - " 5.200935\n", - " 3.410448\n", - " 1.294404\n", - " 0.250156\n", + " 5.230361\n", + " 3.371515\n", + " 1.408195\n", + " 0.201252\n", " 0\n", " \n", " \n", " 1\n", - " 4.892172\n", - " 3.404765\n", - " 1.373966\n", - " 0.317662\n", + " 4.705658\n", + " 3.064075\n", + " 1.388975\n", + " 0.386298\n", " 0\n", " \n", " \n", " 2\n", - " 4.546415\n", - " 3.001362\n", - " 1.379267\n", - " 0.146012\n", + " 4.711709\n", + " 3.056369\n", + " 1.451635\n", + " 0.195365\n", " 0\n", " \n", " \n", " 3\n", - " 6.912333\n", - " 3.372478\n", - " 4.732009\n", - " 1.638499\n", + " 6.981074\n", + " 3.274333\n", + " 4.803886\n", + " 1.623058\n", " 1\n", " \n", " \n", " 4\n", - " 5.479260\n", - " 2.623246\n", - " 3.496161\n", - " 1.265118\n", + " 5.999308\n", + " 2.927207\n", + " 4.040594\n", + " 1.389657\n", " 1\n", " \n", " \n", " 5\n", - " 5.691610\n", - " 2.568420\n", - " 3.620842\n", - " 
1.025988\n", + " 5.698102\n", + " 2.521559\n", + " 3.288451\n", + " 0.966808\n", " 1\n", " \n", " \n", " 6\n", - " 6.935314\n", - " 3.246951\n", - " 6.209702\n", - " 2.236808\n", + " 6.776549\n", + " 3.012238\n", + " 6.285867\n", + " 2.134174\n", " 2\n", " \n", " \n", " 7\n", - " 7.082495\n", - " 3.061208\n", - " 5.907195\n", - " 1.950721\n", + " 7.900000\n", + " 4.400000\n", + " 6.896603\n", + " 2.500000\n", " 2\n", " \n", " \n", " 8\n", - " 6.066010\n", - " 2.553123\n", - " 5.193090\n", - " 1.639034\n", + " 7.900000\n", + " 4.400000\n", + " 6.898989\n", + " 2.500000\n", " 2\n", " \n", " \n", @@ -688,15 +684,15 @@ ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", - "0 5.200935 3.410448 1.294404 0.250156 \n", - "1 4.892172 3.404765 1.373966 0.317662 \n", - "2 4.546415 3.001362 1.379267 0.146012 \n", - "3 6.912333 3.372478 4.732009 1.638499 \n", - "4 5.479260 2.623246 3.496161 1.265118 \n", - "5 5.691610 2.568420 3.620842 1.025988 \n", - "6 6.935314 3.246951 6.209702 2.236808 \n", - "7 7.082495 3.061208 5.907195 1.950721 \n", - "8 6.066010 2.553123 5.193090 1.639034 \n", + "0 5.230361 3.371515 1.408195 0.201252 \n", + "1 4.705658 3.064075 1.388975 0.386298 \n", + "2 4.711709 3.056369 1.451635 0.195365 \n", + "3 6.981074 3.274333 4.803886 1.623058 \n", + "4 5.999308 2.927207 4.040594 1.389657 \n", + "5 5.698102 2.521559 3.288451 0.966808 \n", + "6 6.776549 3.012238 6.285867 2.134174 \n", + "7 7.900000 4.400000 6.896603 2.500000 \n", + "8 7.900000 4.400000 6.898989 2.500000 \n", "\n", " target \n", "0 0 \n", @@ -733,7 +729,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 9, "id": "13df0848", "metadata": {}, "outputs": [ @@ -929,7 +925,7 @@ "max 3.820000 1.080000 14.200000 9.000000 " ] }, - "execution_count": 29, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -945,7 +941,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 10, "id": 
"14bca1cd", "metadata": {}, "outputs": [ @@ -953,35 +949,47 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-27T18:08:18.761007+0200][38480][INFO] Step 100: MLoss: 1.2836 GLoss: 0.9867 Sum: 2.2703\n", - "[2023-03-27T18:08:24.679745+0200][38480][INFO] Step 200: MLoss: 1.2622 GLoss: 0.9409 Sum: 2.2031\n", - "[2023-03-27T18:08:30.391531+0200][38480][INFO] Step 300: MLoss: 1.2059 GLoss: 0.7669 Sum: 1.9727999999999999\n", - "[2023-03-27T18:08:36.164268+0200][38480][INFO] Step 400: MLoss: 1.1645 GLoss: 0.6393 Sum: 1.8038\n", - "[2023-03-27T18:08:41.835318+0200][38480][INFO] Step 500: MLoss: 1.1717 GLoss: 0.6158 Sum: 1.7875\n", - "[2023-03-27T18:08:47.581383+0200][38480][INFO] Step 600: MLoss: 1.1946 GLoss: 0.5384 Sum: 1.733\n", - "[2023-03-27T18:08:53.378127+0200][38480][INFO] Step 700: MLoss: 1.1343 GLoss: 0.5135 Sum: 1.6478000000000002\n", - "[2023-03-27T18:08:59.698145+0200][38480][INFO] Step 800: MLoss: 1.1168 GLoss: 0.4788 Sum: 1.5956000000000001\n", - "[2023-03-27T18:09:05.752638+0200][38480][INFO] Step 900: MLoss: 1.1034 GLoss: 0.4734 Sum: 1.5768\n", - "[2023-03-27T18:09:12.070003+0200][38480][INFO] Step 1000: MLoss: 1.142 GLoss: 0.4692 Sum: 1.6112\n", - "[2023-03-27T18:09:18.112377+0200][38480][INFO] Step 1100: MLoss: 1.1691 GLoss: 0.4602 Sum: 1.6293\n", - "[2023-03-27T18:09:25.549484+0200][38480][INFO] Step 1200: MLoss: 1.1201 GLoss: 0.4578 Sum: 1.5779\n", - "[2023-03-27T18:09:31.574874+0200][38480][INFO] Step 1300: MLoss: 1.1436 GLoss: 0.4429 Sum: 1.5865\n", - "[2023-03-27T18:09:37.672797+0200][38480][INFO] Step 1400: MLoss: 1.1093 GLoss: 0.449 Sum: 1.5583\n", - "[2023-03-27T18:09:44.149652+0200][38480][INFO] Step 1500: MLoss: 1.1468 GLoss: 0.4347 Sum: 1.5815000000000001\n", - "[2023-03-27T18:09:49.923915+0200][38480][INFO] Step 1600: MLoss: 1.1545 GLoss: 0.4313 Sum: 1.5858\n", - "[2023-03-27T18:09:55.733558+0200][38480][INFO] Step 1700: MLoss: 1.102 GLoss: 0.4305 Sum: 1.5325000000000002\n", - "[2023-03-27T18:10:03.367053+0200][38480][INFO] 
Step 1800: MLoss: 1.0953 GLoss: 0.4267 Sum: 1.522\n", - "[2023-03-27T18:10:10.533359+0200][38480][INFO] Step 1900: MLoss: 1.1247 GLoss: 0.4223 Sum: 1.5470000000000002\n", - "[2023-03-27T18:10:17.355705+0200][38480][INFO] Step 2000: MLoss: 1.2767 GLoss: 0.4266 Sum: 1.7033\n" + "[2023-03-31T01:04:50.377220+0200][12004][INFO] Encoding fixed acidity 8821222230854998919\n", + "[2023-03-31T01:04:50.427480+0200][12004][INFO] Encoding volatile acidity 3689048099044143611\n", + "[2023-03-31T01:04:50.442050+0200][12004][INFO] Encoding citric acid 735380040632581265\n", + "[2023-03-31T01:04:50.457233+0200][12004][INFO] Encoding residual sugar 2442409671939919968\n", + "[2023-03-31T01:04:50.473234+0200][12004][INFO] Encoding chlorides 7195838597182208600\n", + "[2023-03-31T01:04:50.488234+0200][12004][INFO] Encoding free sulfur dioxide 3309873879720413309\n", + "[2023-03-31T01:04:50.501098+0200][12004][INFO] Encoding total sulfur dioxide 8059822526963442530\n", + "[2023-03-31T01:04:50.512236+0200][12004][INFO] Encoding density 3625281346475756911\n", + "[2023-03-31T01:04:50.523222+0200][12004][INFO] Encoding pH 4552002723230490789\n", + "[2023-03-31T01:04:50.532220+0200][12004][INFO] Encoding sulphates 4957484118723629481\n", + "[2023-03-31T01:04:50.540983+0200][12004][INFO] Encoding alcohol 3711001505059098944\n", + "[2023-03-31T01:04:50.547987+0200][12004][INFO] Encoding quality 3457201635469827215\n", + "[2023-03-31T01:04:58.399971+0200][12004][INFO] Step 100: MLoss: 1.3342 GLoss: 0.9783 Sum: 2.3125\n", + "[2023-03-31T01:05:04.973385+0200][12004][INFO] Step 200: MLoss: 1.2858 GLoss: 0.9031 Sum: 2.1889000000000003\n", + "[2023-03-31T01:05:11.741000+0200][12004][INFO] Step 300: MLoss: 1.186 GLoss: 0.7758 Sum: 1.9618\n", + "[2023-03-31T01:05:18.619270+0200][12004][INFO] Step 400: MLoss: 1.1481 GLoss: 0.6615 Sum: 1.8095999999999999\n", + "[2023-03-31T01:05:24.930108+0200][12004][INFO] Step 500: MLoss: 1.1661 GLoss: 0.6094 Sum: 1.7755\n", + 
"[2023-03-31T01:05:31.651906+0200][12004][INFO] Step 600: MLoss: 1.1902 GLoss: 0.5381 Sum: 1.7283\n", + "[2023-03-31T01:05:38.246164+0200][12004][INFO] Step 700: MLoss: 1.1305 GLoss: 0.5087 Sum: 1.6392000000000002\n", + "[2023-03-31T01:05:44.776216+0200][12004][INFO] Step 800: MLoss: 1.1131 GLoss: 0.4832 Sum: 1.5963\n", + "[2023-03-31T01:05:51.917105+0200][12004][INFO] Step 900: MLoss: 1.1014 GLoss: 0.4786 Sum: 1.58\n", + "[2023-03-31T01:05:59.098745+0200][12004][INFO] Step 1000: MLoss: 1.1479 GLoss: 0.4707 Sum: 1.6185999999999998\n", + "[2023-03-31T01:06:05.690366+0200][12004][INFO] Step 1100: MLoss: 1.1712 GLoss: 0.4693 Sum: 1.6405\n", + "[2023-03-31T01:06:12.549553+0200][12004][INFO] Step 1200: MLoss: 1.1199 GLoss: 0.4611 Sum: 1.581\n", + "[2023-03-31T01:06:19.575478+0200][12004][INFO] Step 1300: MLoss: 1.1525 GLoss: 0.4614 Sum: 1.6139000000000001\n", + "[2023-03-31T01:06:26.641319+0200][12004][INFO] Step 1400: MLoss: 1.1164 GLoss: 0.4671 Sum: 1.5835000000000001\n", + "[2023-03-31T01:06:33.249503+0200][12004][INFO] Step 1500: MLoss: 1.1356 GLoss: 0.4577 Sum: 1.5933\n", + "[2023-03-31T01:06:40.025759+0200][12004][INFO] Step 1600: MLoss: 1.1367 GLoss: 0.4541 Sum: 1.5908\n", + "[2023-03-31T01:06:46.754777+0200][12004][INFO] Step 1700: MLoss: 1.0896 GLoss: 0.4524 Sum: 1.5419999999999998\n", + "[2023-03-31T01:06:54.036939+0200][12004][INFO] Step 1800: MLoss: 1.075 GLoss: 0.4471 Sum: 1.5221\n", + "[2023-03-31T01:07:00.554405+0200][12004][INFO] Step 1900: MLoss: 1.1154 GLoss: 0.4495 Sum: 1.5649\n", + "[2023-03-31T01:07:07.289610+0200][12004][INFO] Step 2000: MLoss: 1.266 GLoss: 0.454 Sum: 1.72\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 47, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1004,7 +1012,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 11, "id": "83064f94", "metadata": {}, "outputs": [ @@ -1014,13 +1022,13 @@ "" ] }, - "execution_count": 48, + "execution_count": 11, 
"metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1037,7 +1045,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 12, "id": "af9d6df1", "metadata": {}, "outputs": [ @@ -1079,123 +1087,123 @@ " \n", " \n", " 0\n", - " 53.753993\n", - " -2.475239\n", - " 0.404968\n", - " 406.898386\n", - " 1.788962\n", - " 450.073724\n", - " 1221.551077\n", - " 5.697177\n", - " 12.451377\n", - " 14.835445\n", - " 81.515696\n", - " 0.0\n", + " 7.400000\n", + " 0.080079\n", + " 0.000000\n", + " 65.800000\n", + " 0.025529\n", + " 67.119106\n", + " 440.000000\n", + " 1.038980\n", + " 2.720000\n", + " 1.080000\n", + " 8.0\n", + " 7\n", " \n", " \n", " 1\n", - " 241.769932\n", - " -33.905933\n", - " 7.440188\n", - " 1722.536053\n", - " 48.034556\n", - " 1312.264457\n", - " 2141.119368\n", - " 31.259074\n", - " 83.654749\n", - " 28.591759\n", - " 489.172674\n", - " 3.0\n", + " 5.088797\n", + " 0.112499\n", + " 0.370000\n", + " 0.763677\n", + " 0.009000\n", + " 288.824821\n", + " 98.000000\n", + " 0.987110\n", + " 3.240000\n", + " 0.220000\n", + " 14.2\n", + " 8\n", " \n", " \n", " 2\n", - " 25.344904\n", - " 0.769463\n", - " -11.237007\n", - " -335.794326\n", - " -3.595284\n", - " -234.179124\n", - " 382.907515\n", - " 7.637684\n", - " 17.748300\n", - " 3.380296\n", - " 73.701048\n", - " 1.0\n", + " 3.800000\n", + " 1.100000\n", + " 0.000000\n", + " 0.600000\n", + " 0.009000\n", + " 2.000000\n", + " 9.000000\n", + " 1.038980\n", + " 3.820000\n", + " 0.220000\n", + " 8.0\n", + " 4\n", " \n", " \n", " 3\n", - " 15.635557\n", - " -28.371864\n", - " -19.808469\n", - " 800.088446\n", - " 61.404066\n", - " -596.053591\n", - " -1749.797505\n", - " 28.376345\n", - " -71.868790\n", - " -14.556346\n", - " -38.315179\n", - " 1.0\n", + " 3.800000\n", + " 0.080000\n", + " 1.659603\n", + " 0.600000\n", + " 0.034734\n", + " 2.000000\n", + " 9.000000\n", + " 0.987110\n", + " 3.775879\n", + " 1.080000\n", + " 9.5\n", + " 7\n", " \n", " \n", " 4\n", - " -0.796959\n", - " -8.546869\n", - " -4.726590\n", - " 128.343028\n", - 
" 1.083628\n", - " -288.352104\n", - " 1184.680273\n", - " 8.081500\n", - " 23.012828\n", - " 2.168597\n", - " 36.672840\n", - " 0.0\n", + " 5.700000\n", + " 0.330000\n", + " 0.213874\n", + " 10.937306\n", + " 0.050000\n", + " 39.064968\n", + " 147.790987\n", + " 0.997247\n", + " 3.330984\n", + " 0.380000\n", + " 8.7\n", + " 6\n", " \n", " \n", " 5\n", - " -31.203381\n", - " -39.052177\n", - " -57.651032\n", - " 1269.158981\n", - " -22.793850\n", - " 101.490751\n", - " -661.997823\n", - " 5.012738\n", - " 19.615822\n", - " 26.791456\n", - " -63.773678\n", - " 3.0\n", + " 14.200000\n", + " 0.080000\n", + " 0.000000\n", + " 0.600000\n", + " 0.009000\n", + " 2.000000\n", + " 9.000000\n", + " 0.987110\n", + " 2.916428\n", + " 0.220055\n", + " 9.5\n", + " 5\n", " \n", " \n", " 6\n", - " -120.526480\n", - " -49.314650\n", - " -67.642982\n", - " 650.136816\n", - " 65.155843\n", - " 598.106999\n", - " -3468.753037\n", - " 3.750566\n", - " 52.556860\n", - " -108.310847\n", - " -91.816310\n", - " 3.0\n", + " 14.200000\n", + " 0.087887\n", + " 1.660000\n", + " 0.600000\n", + " 0.108297\n", + " 49.000000\n", + " 65.466909\n", + " 0.987117\n", + " 2.720090\n", + " 0.220006\n", + " 9.0\n", + " 6\n", " \n", " \n", " 7\n", - " 13.172627\n", - " -7.196406\n", - " -20.153565\n", - " 746.262383\n", - " -30.846688\n", - " 1592.815397\n", - " 1610.699379\n", - " -15.576660\n", - " 27.319692\n", - " 45.376814\n", - " 135.871422\n", - " 0.0\n", + " 8.870765\n", + " 1.099817\n", + " 1.657142\n", + " 12.921528\n", + " 0.025276\n", + " 288.846488\n", + " 438.337342\n", + " 0.996196\n", + " 2.724725\n", + " 0.220049\n", + " 10.2\n", + " 5\n", " \n", " \n", "\n", @@ -1203,43 +1211,43 @@ ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", - "0 53.753993 -2.475239 0.404968 406.898386 1.788962 \n", - "1 241.769932 -33.905933 7.440188 1722.536053 48.034556 \n", - "2 25.344904 0.769463 -11.237007 -335.794326 -3.595284 \n", - "3 15.635557 -28.371864 
-19.808469 800.088446 61.404066 \n", - "4 -0.796959 -8.546869 -4.726590 128.343028 1.083628 \n", - "5 -31.203381 -39.052177 -57.651032 1269.158981 -22.793850 \n", - "6 -120.526480 -49.314650 -67.642982 650.136816 65.155843 \n", - "7 13.172627 -7.196406 -20.153565 746.262383 -30.846688 \n", + "0 7.400000 0.080079 0.000000 65.800000 0.025529 \n", + "1 5.088797 0.112499 0.370000 0.763677 0.009000 \n", + "2 3.800000 1.100000 0.000000 0.600000 0.009000 \n", + "3 3.800000 0.080000 1.659603 0.600000 0.034734 \n", + "4 5.700000 0.330000 0.213874 10.937306 0.050000 \n", + "5 14.200000 0.080000 0.000000 0.600000 0.009000 \n", + "6 14.200000 0.087887 1.660000 0.600000 0.108297 \n", + "7 8.870765 1.099817 1.657142 12.921528 0.025276 \n", "\n", - " free sulfur dioxide total sulfur dioxide density pH \\\n", - "0 450.073724 1221.551077 5.697177 12.451377 \n", - "1 1312.264457 2141.119368 31.259074 83.654749 \n", - "2 -234.179124 382.907515 7.637684 17.748300 \n", - "3 -596.053591 -1749.797505 28.376345 -71.868790 \n", - "4 -288.352104 1184.680273 8.081500 23.012828 \n", - "5 101.490751 -661.997823 5.012738 19.615822 \n", - "6 598.106999 -3468.753037 3.750566 52.556860 \n", - "7 1592.815397 1610.699379 -15.576660 27.319692 \n", + " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", + "0 67.119106 440.000000 1.038980 2.720000 1.080000 \n", + "1 288.824821 98.000000 0.987110 3.240000 0.220000 \n", + "2 2.000000 9.000000 1.038980 3.820000 0.220000 \n", + "3 2.000000 9.000000 0.987110 3.775879 1.080000 \n", + "4 39.064968 147.790987 0.997247 3.330984 0.380000 \n", + "5 2.000000 9.000000 0.987110 2.916428 0.220055 \n", + "6 49.000000 65.466909 0.987117 2.720090 0.220006 \n", + "7 288.846488 438.337342 0.996196 2.724725 0.220049 \n", "\n", - " sulphates alcohol quality \n", - "0 14.835445 81.515696 0.0 \n", - "1 28.591759 489.172674 3.0 \n", - "2 3.380296 73.701048 1.0 \n", - "3 -14.556346 -38.315179 1.0 \n", - "4 2.168597 36.672840 0.0 \n", - "5 26.791456 -63.773678 
3.0 \n", - "6 -108.310847 -91.816310 3.0 \n", - "7 45.376814 135.871422 0.0 " + " alcohol quality \n", + "0 8.0 7 \n", + "1 14.2 8 \n", + "2 8.0 4 \n", + "3 9.5 7 \n", + "4 8.7 6 \n", + "5 9.5 5 \n", + "6 9.0 6 \n", + "7 10.2 5 " ] }, - "execution_count": 51, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "plugin.model.generate(8)" + "plugin.generate(8)" ] }, { @@ -1255,7 +1263,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 13, "id": "56a1fc7e", "metadata": {}, "outputs": [ @@ -1263,35 +1271,47 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-27T18:03:45.005934+0200][38480][INFO] Step 100: MLoss: 0.9066 GLoss: 1.0013 Sum: 1.9079000000000002\n", - "[2023-03-27T18:03:51.387087+0200][38480][INFO] Step 200: MLoss: 0.4735 GLoss: 1.0112 Sum: 1.4847000000000001\n", - "[2023-03-27T18:03:59.107456+0200][38480][INFO] Step 300: MLoss: 0.4567 GLoss: 1.001 Sum: 1.4577\n", - "[2023-03-27T18:04:05.835508+0200][38480][INFO] Step 400: MLoss: 0.2715 GLoss: 0.9856 Sum: 1.2571\n", - "[2023-03-27T18:04:12.739590+0200][38480][INFO] Step 500: MLoss: 0.2193 GLoss: 0.9046 Sum: 1.1239\n", - "[2023-03-27T18:04:19.417762+0200][38480][INFO] Step 600: MLoss: 0.0143 GLoss: 0.8463 Sum: 0.8606\n", - "[2023-03-27T18:04:26.022729+0200][38480][INFO] Step 700: MLoss: 0.0048 GLoss: 0.7509 Sum: 0.7557\n", - "[2023-03-27T18:04:32.757598+0200][38480][INFO] Step 800: MLoss: 0.0083 GLoss: 0.7102 Sum: 0.7185\n", - "[2023-03-27T18:04:39.550873+0200][38480][INFO] Step 900: MLoss: 0.0029 GLoss: 0.675 Sum: 0.6779000000000001\n", - "[2023-03-27T18:04:46.573464+0200][38480][INFO] Step 1000: MLoss: 0.0039 GLoss: 0.6414 Sum: 0.6453\n", - "[2023-03-27T18:04:53.438631+0200][38480][INFO] Step 1100: MLoss: 0.003 GLoss: 0.6046 Sum: 0.6076\n", - "[2023-03-27T18:05:01.283222+0200][38480][INFO] Step 1200: MLoss: 0.0013 GLoss: 0.6297 Sum: 0.631\n", - "[2023-03-27T18:05:08.559280+0200][38480][INFO] Step 1300: MLoss: 0.0012 GLoss: 
0.5479 Sum: 0.5491\n", - "[2023-03-27T18:05:15.536738+0200][38480][INFO] Step 1400: MLoss: 0.0067 GLoss: 0.5275 Sum: 0.5342\n", - "[2023-03-27T18:05:22.391711+0200][38480][INFO] Step 1500: MLoss: 0.0007 GLoss: 0.5252 Sum: 0.5259\n", - "[2023-03-27T18:05:29.285959+0200][38480][INFO] Step 1600: MLoss: 0.0018 GLoss: 0.5017 Sum: 0.5035000000000001\n", - "[2023-03-27T18:05:36.288634+0200][38480][INFO] Step 1700: MLoss: 0.0012 GLoss: 0.5013 Sum: 0.5025\n", - "[2023-03-27T18:05:43.485831+0200][38480][INFO] Step 1800: MLoss: 0.0009 GLoss: 0.4927 Sum: 0.49360000000000004\n", - "[2023-03-27T18:05:50.629387+0200][38480][INFO] Step 1900: MLoss: 0.0009 GLoss: 0.4931 Sum: 0.494\n", - "[2023-03-27T18:05:58.709478+0200][38480][INFO] Step 2000: MLoss: 0.0006 GLoss: 0.4864 Sum: 0.487\n" + "[2023-03-31T01:07:08.859587+0200][12004][INFO] Encoding fixed acidity 8821222230854998919\n", + "[2023-03-31T01:07:08.873767+0200][12004][INFO] Encoding volatile acidity 3689048099044143611\n", + "[2023-03-31T01:07:08.885765+0200][12004][INFO] Encoding citric acid 735380040632581265\n", + "[2023-03-31T01:07:08.896357+0200][12004][INFO] Encoding residual sugar 2442409671939919968\n", + "[2023-03-31T01:07:08.904579+0200][12004][INFO] Encoding chlorides 7195838597182208600\n", + "[2023-03-31T01:07:08.914577+0200][12004][INFO] Encoding free sulfur dioxide 3309873879720413309\n", + "[2023-03-31T01:07:08.922581+0200][12004][INFO] Encoding total sulfur dioxide 8059822526963442530\n", + "[2023-03-31T01:07:08.930580+0200][12004][INFO] Encoding density 3625281346475756911\n", + "[2023-03-31T01:07:08.939216+0200][12004][INFO] Encoding pH 4552002723230490789\n", + "[2023-03-31T01:07:08.947216+0200][12004][INFO] Encoding sulphates 4957484118723629481\n", + "[2023-03-31T01:07:08.956217+0200][12004][INFO] Encoding alcohol 3711001505059098944\n", + "[2023-03-31T01:07:08.964215+0200][12004][INFO] Encoding quality 3457201635469827215\n", + "[2023-03-31T01:07:17.078379+0200][12004][INFO] Step 100: MLoss: 0.9932 
GLoss: 0.9775 Sum: 1.9707\n", + "[2023-03-31T01:07:24.055012+0200][12004][INFO] Step 200: MLoss: 0.2957 GLoss: 0.9254 Sum: 1.2211\n", + "[2023-03-31T01:07:32.461826+0200][12004][INFO] Step 300: MLoss: 0.0748 GLoss: 0.8407 Sum: 0.9155\n", + "[2023-03-31T01:07:39.522162+0200][12004][INFO] Step 400: MLoss: 0.0289 GLoss: 0.7444 Sum: 0.7733\n", + "[2023-03-31T01:07:47.110402+0200][12004][INFO] Step 500: MLoss: 0.0292 GLoss: 0.6655 Sum: 0.6947\n", + "[2023-03-31T01:07:54.622795+0200][12004][INFO] Step 600: MLoss: 0.0229 GLoss: 0.5844 Sum: 0.6073000000000001\n", + "[2023-03-31T01:08:01.951234+0200][12004][INFO] Step 700: MLoss: 0.0218 GLoss: 0.5572 Sum: 0.5790000000000001\n", + "[2023-03-31T01:08:09.957993+0200][12004][INFO] Step 800: MLoss: 0.0091 GLoss: 0.531 Sum: 0.5401\n", + "[2023-03-31T01:08:18.931373+0200][12004][INFO] Step 900: MLoss: 0.0114 GLoss: 0.5286 Sum: 0.5399999999999999\n", + "[2023-03-31T01:08:26.898063+0200][12004][INFO] Step 1000: MLoss: 0.0099 GLoss: 0.5259 Sum: 0.5358\n", + "[2023-03-31T01:08:34.593930+0200][12004][INFO] Step 1100: MLoss: 0.0106 GLoss: 0.5196 Sum: 0.5302\n", + "[2023-03-31T01:08:41.818482+0200][12004][INFO] Step 1200: MLoss: 0.0105 GLoss: 0.5072 Sum: 0.5176999999999999\n", + "[2023-03-31T01:08:49.426481+0200][12004][INFO] Step 1300: MLoss: 0.0086 GLoss: 0.5112 Sum: 0.5198\n", + "[2023-03-31T01:08:56.953344+0200][12004][INFO] Step 1400: MLoss: 0.0106 GLoss: 0.516 Sum: 0.5266000000000001\n", + "[2023-03-31T01:09:04.509760+0200][12004][INFO] Step 1500: MLoss: 0.0075 GLoss: 0.5062 Sum: 0.5136999999999999\n", + "[2023-03-31T01:09:11.742216+0200][12004][INFO] Step 1600: MLoss: 0.0098 GLoss: 0.5012 Sum: 0.511\n", + "[2023-03-31T01:09:19.870988+0200][12004][INFO] Step 1700: MLoss: 0.0088 GLoss: 0.499 Sum: 0.5078\n", + "[2023-03-31T01:09:27.578035+0200][12004][INFO] Step 1800: MLoss: 0.0163 GLoss: 0.4956 Sum: 0.5119\n", + "[2023-03-31T01:09:34.406045+0200][12004][INFO] Step 1900: MLoss: 0.0046 GLoss: 0.4955 Sum: 0.5001\n", + 
"[2023-03-31T01:09:41.645411+0200][12004][INFO] Step 2000: MLoss: 0.017 GLoss: 0.5008 Sum: 0.5178\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 43, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1302,7 +1322,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 14, "id": "3fcb9493", "metadata": {}, "outputs": [ @@ -1312,13 +1332,13 @@ "" ] }, - "execution_count": 44, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEGCAYAAAB1iW6ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAAsTAAALEwEAmpwYAABHmElEQVR4nO3dd3ib1dn48e/RtizvHY84e5G9EwiEsCmrQNkkjPLCWyhddMFbeFt+BQodb0vL3quUPQIEwkoCZJO9nG073tuSrXl+f0h2nOCZ2JZl35/r0mVZOtJz67F86+h+znOO0lojhBAi8hnCHYAQQojuIQldCCH6CUnoQgjRT0hCF0KIfkISuhBC9BOmcG04OTlZ5+bmhmvzQggRkdatW1eutU5p7b6wJfTc3FzWrl0brs0LIUREUkodaOs+KbkIIUQ/IQldCCH6CUnoQgjRT4Sthi6EEO3xer0UFBTQ2NgY7lDCwmazkZWVhdls7vRjJKELIfqkgoICYmJiyM3NRSkV7nB6ldaaiooKCgoKGDJkSKcfJyUXIUSf1NjYSFJS0oBL5gBKKZKSkrr87aTDhK6UylZKfa6U2qaU2qqUur2VNqcopWqUUhtCl991KQohhGjFQEzmTY7ltXem5OIDfq61Xq+UigHWKaU+0VpvO6rdcq3197ocQRflVeXx4b4PuXbstcTb4nt6c0IIETE67KFrrYu01utD1+uA7UBmTwfWloO1B3li8xMUOYvCFYIQQjR79tlnufXWW8MdBtDFGrpSKheYDKxq5e7ZSqmNSqkPlVLj2nj8TUqptUqptWVlZV2PFkiMSgSgsrHymB4vhBD9VacTulLKAbwB/ERrXXvU3euBwVrricA/gLdbew6t9eNa62la62kpKa1ORdChRJskdCFE79i/fz+jR49m0aJFjBw5kquuuoqlS5cyd+5cRowYwerVq7/T/tRTT2XChAksWLCAgwcPAvDaa69xwgknMHHiRObNmwfA1q1bmTFjBpMmTWLChAnk5eUdd7ydGraolDITTOYvaa3fPPr+lglea/2BUupfSqlkrXX5cUd4FEnoQgw8//veVrYdOrofeXzGDorl7vNaLSYcYffu3bz22ms8/fTTTJ8+nZdffpkVK1bw7rvv8sc//pELL7ywue1tt93GwoULWbhwIU8//TQ//vGPefvtt/n973/PkiVLyMzMpLq6GoBHH32U22+/nauuugqPx4Pf7z/u19SZUS4KeArYrrX+Sxtt0kPtUErNCD1vxXFH1wqH2YHZYKaisUeeXgghjjBkyBDGjx+PwWBg3LhxLFiwAKUU48ePZ//+/Ue0/eabb7jyyisBuOaaa1ixYgUAc+fOZdGiRTzxxBPNiXv27Nn88Y9/5IEHHuDAgQNERUUdd6yd6aHPBa4BNiulNoRu+y2QA6C1fh
S4BLhFKeUDGoDLdQ+tPq2UItGWSGWD9NCFGCg605PuKVartfm6wWBo/t1gMODz+Tr1HI8++iirVq1i8eLFTJ06lXXr1nHllVcyc+ZMFi9ezDnnnMNjjz3GqaeeelyxdpjQtdYrgHYHRGqtHwYePq5IuiDRliglFyFEnzNnzhz+/e9/c8011/DSSy9x0kknAbBnzx5mzpzJzJkz+fDDD8nPz6empoahQ4fy4x//mIMHD7Jp06aeT+h9UVJUkpRchBB9zj/+8Q+uu+46HnzwQVJSUnjmmWcAuOOOO8jLy0NrzYIFC5g4cSIPPPAAL7zwAmazmfT0dH77298e9/ZVD1VGOjRt2jR9rAtc3LniTlYXr+aTSz7p5qiEEH3F9u3bGTNmTLjDCKvW9oFSap3Welpr7SNyLpckWxKVDZWE68NICCH6oohM6Im2RDwBD06vM9yhCCFEnxGZCV3OFhVCiO+IzITe4uSiQ/WHwhyNEEL0DRGd0N/e/TZnvnEm60rWhTkiIYQIv4hO6O/teQ+Aj/Z9FM5whBCiT4johO4JeAD49OCnBHQgnCEJIQaQRYsW8frrr4c7jO+IyIRuMVqIMccAcPGIiylrKGNT2aYwRyWEEOEVkQkdgiNdUqJS+OnUn2IymFh6YGm4QxJC9EN/+MMfGDVqFCeeeCJXXHEFDz300BH3f/rpp0yePJnx48dz/fXX43a7Afj1r3/N2LFjmTBhAr/4xS+A1qfR7U4Reeo/wFVjrsJhdhBnjWPOoDksObCEn037GQYVsZ9RQoi2fPhrKN7cvc+ZPh7Ovr/dJmvWrOGNN95g48aNeL1epkyZwtSpU5vvb2xsZNGiRXz66aeMHDmSa6+9lkceeYRrrrmGt956ix07dqCUap4yt7VpdLtTxGa/K0ZfwXnDzgPg3CHnUuwsZn3J+jBHJYToT7766isuuOACbDYbMTExnHfeeUfcv3PnToYMGcLIkSMBWLhwIcuWLSMuLg6bzcYNN9zAm2++id1uB1qfRrc7RWwPvaX5OfOxm+y8v/d9pqW3OsWBECKSddCT7mtMJhOrV6/m008/5fXXX+fhhx/ms88+a3Ua3aSkpG7bbsT20FuKMkVx2uDT+Hj/x+yt2YvH7wl3SEKIfmDu3Lm89957NDY2Ul9fz/vvv3/E/aNGjWL//v3s3r0bgBdeeIGTTz6Z+vp6ampqOOecc/jrX//Kxo0bgcPT6P7+978nJSWF/Pz8bo23X/TQAc4bdh7v7nmXC96+gCmpU3ju7OfCHZIQIsJNnz6d888/nwkTJpCWlsb48eOJi4trvt9ms/HMM89w6aWX4vP5mD59OjfffDOVlZVccMEFNDY2orXmL38JLvbW2jS63Skip89tjdaatSVreTPvTRbvXczyy5cTZ43r+IFCiD6pr0yfW19fj8PhwOVyMW/ePB5//HGmTJnSK9seENPntkYpxfT06Vw68lI0mjXFa8IdkhCiH7jpppuYNGkSU6ZM4eKLL+61ZH4s+k3Jpcn45PFEmaJYWbSS0wafFu5whBAR7uWXXw53CJ3Wb3roTcxGM1PTprK6eDVLDyzl1R2vhjskIYToFf0uoQPMypjFvpp9/PSLn3L/6vupcdeEOyQhhOhx/TKhzxk0B4ViXNI4fNrHsoJl4Q5JCCF6XL9M6CMSRvD+Re/zwjkvkBqVyqcHPw13SEII0eP6ZUIHyInNwWwwMz9nPl8VfkWDryHcIQkhIozD4Qh3CF3SbxN6kwU5C2j0N7Ly0MpwhyKEED2q3yf0yamTAdhRtSPMkQghIpXWmjvuuIMTTjiB8ePH8+qrwdFzRUVFzJs3j0mTJnHCCSewfPly/H4/ixYtam7717/+tdfi7Hfj0I9mM9lItadSUFcQ7lCEEMfogdUPsKOyeztloxNH86sZv+pU2zfffJMNGzawceNGysvLmT59OvPmzePll1/mzDPP5M4778Tv9+NyudiwYQOFhYVs2bIFoEemyW1Lv++hA2THZJ
Nf172T4AghBo4VK1ZwxRVXYDQaSUtL4+STT2bNmjVMnz6dZ555hnvuuYfNmzcTExPD0KFD2bt3L7fddhsfffQRsbGxvRZnv++hQzChryhcEe4whBDHqLM96d42b948li1bxuLFi1m0aBE/+9nPuPbaa9m4cSNLlizh0Ucf5T//+Q9PP/10r8QzYHro5Q3luLyucIcihIhAJ510Eq+++ip+v5+ysjKWLVvGjBkzOHDgAGlpafzwhz/kxhtvZP369ZSXlxMIBLj44ou59957Wb++9xbeGTA9dICC+gJGJowMczRCiEhz0UUX8c033zBx4kSUUvzpT38iPT2d5557jgcffBCz2YzD4eD555+nsLCQ6667jkAgAMB9993Xa3F2mNCVUtnA80AaoIHHtdb/d1QbBfwfcA7gAhZprfvMenBNCT2/Ll8SuhCi0+rr64HgbK4PPvggDz744BH3L1y4kIULF37ncb3ZK2+pMz10H/BzrfV6pVQMsE4p9YnWeluLNmcDI0KXmcAjoZ99QnMPXUa6CCH6sQ5r6Frroqbetta6DtgOZB7V7ALgeR20EohXSmV0e7THKM4aR4wlRka6CCH6tS4dFFVK5QKTgVVH3ZUJtMyWBXw36aOUukkptVYptbasrKyLoR4fGbooROQJ14pqfcGxvPZOJ3SllAN4A/iJ1rq2y1sCtNaPa62naa2npaSkHMtTHDNJ6EJEFpvNRkVFxYBM6lprKioqsNlsXXpcp0a5KKXMBJP5S1rrN1tpUghkt/g9K3Rbn5HpyOTTg58S0AEMakCM1hQiomVlZVFQUEBvf5vvK2w2G1lZWV16TGdGuSjgKWC71vovbTR7F7hVKfVvggdDa7TWRV2KpIel2lPxBXxUNlaSHJUc7nCEEB0wm80MGTIk3GFElM700OcC1wCblVIbQrf9FsgB0Fo/CnxAcMjiboLDFq/r9kiPU7o9HYBSV6kkdCFEv9RhQtdarwBUB2008KPuCqonpNpTgWBCH5s0NszRCCFE9xswxeSmhF7iLAlzJEII0TMGTEJPjkrGqIyUuCShCyH6pwGT0I0GI0lRSZLQhRD91oBJ6BA8MFrqKg13GEII0SMGVEJPtadKD10I0W8NqISeFp0mPXQhRL81oBJ6qj0Vp9dJvac+3KEIIUS3G1AJPc2eBiC9dCFEvzSgEnrzWHSpowsh+qEBldCbTv+XhC6E6I8GVEJPjU5FoSis71MTQQohRLcYUAndarSSE5vD7qrd4Q5FCCG63YBK6AAj4keQV50X7jCEEKLbDbyEnjCCg7UHafA1hDsUIYToVgMyoWs0e6v3hjsUIYToVgMvocePAGBX1a4wRyKEEN1rwCX07JhsbEab1NGFEP3OgEvoRoORYfHDyKvKw+P3DMgVxYUQ/dOAS+gQrKOvKV7DtBen8cK2F8IdjhBCdIsBmdC/N/R7zBo0iwRbAquLV4c7HCGE6BYDMqHPzJjJo6c9ypxBc9heuT3c4QghRLcYkAm9yZjEMZS6SqloqAh3KEIIcdwGdkJPGgPAjsodYY5ECCGO34BO6KMSRwFI2UUI0S8M6IQea4kly5HFtopt4Q5FCCGO24BO6BAsu0jJRQjRH0hCTxxDfl0+Tq+zS4+rbKyk2FncQ1EJIUTXDfiEPjh2MAD5dfldetwDqx/g51/+vCdCEkKIYzLgE3p2TDYAB2sPdulx5Q3lFNdLD10I0XdIQg8l9K720J1eJ1XuKpkLRgjRZ3SY0JVSTyulSpVSW9q4/xSlVI1SakPo8rvuD7PnOCwOEm2Jx5TQvQGvLJQhhOgzOtNDfxY4q4M2y7XWk0KX3x9/WL0rKyaLgrqCLj3G5XUBUOWu6omQhBCiyzpM6FrrZUBlL8QSNtkx2Rys61oN3ekLjoqpdlf3QERCCNF13VVDn62U2qiU+lApNa6tRkqpm5RSa5VSa8vKyrpp08cvOyabYmcxHr+nU+211s099JrGmp4MTQghOq07Evp6YLDWei
LwD+DtthpqrR/XWk/TWk9LSUnphk13j+yYbDSawvrCTrVv8DWgCR4MlZKLEKKvOO6ErrWu1VrXh65/AJiVUsnHHVkv6upIl5YnIUnJRQjRVxx3QldKpSulVOj6jNBzRtR8tMeT0GvcUnIRQvQNnRm2+ArwDTBKKVWglLpBKXWzUurmUJNLgC1KqY3A34HLdYQNzk6yJeEwO/hg7wedStBNB0QBqhql5CKE6BtMHTXQWl/Rwf0PAw93W0RhoJTirll38T9f/Q9Xf3A1r533GjaTrc32TQdEQXroQoi+Y8CfKdrk3KHncu/ce9lfu5/N5ZvbbduU0E0GkxwUFUL0GZLQW5gzaA4AG8s2ttuuqYY+KHqQ9NCFEH2GJPQW4m3x5MbmsqlsU7vtmmrogxyDZJSLEKLPkIR+lAkpE9hUtqndSbeaSi6ZjkxJ6EKIPkMS+lEmJE+gorGCQ85DbbZpKrlkRGfQ4GvA7Xf3VnhCCNGmiEvoNQ1elmwtxh/QeHwB3lhXgM8f6Lbnn5AyAaDdsovT6yTKFEWCLQGA6sbqbtu+EEIcq4hL6J/vKOW/XljH9qJaPtxSxM9f28j7m4q67flHJIzAZrS1e2DU6XViN9kPJ3Qpuwgh+oCIS+hzhiUB8PWecr7ZEzwh9a1vOzcHS2eYDCbGJI1hW8W2Ntu4vC6izdHEW+MBSehCiL4h4hJ6aqyN4akOvt5Twcq9wYS+PK+Msrruq2OPTRrLjsod+AP+Vu93+pxHJHQ5W1QI0RdEXEKHYC/9690V7K9wcfn0bAIa3tvY9kHMrhqTOIYGXwMHag+0er/T68RutpMSFZwxstRV2m3bFkKIYxWxCd0TOhB69azBnJAZyxvru7biUHvGJI0BYFtl62WXppJLnDUOm9FGiauk27YthBDHKiIT+qyhSSgFsTYTYzJiuWxaNlsP1bK5oHvO2hwaNxSr0cr2iu2t3u/0Ook2RaOUIi06TRK6EKJPiMiEHm+3MCM3kfmjUzEaFBdMzsRmNvDy6q4tI9cWk8HEqIRRbR4YbSq5AKTZ0yhxSkIXQoRfh7Mt9lXPXT8DQ3AadmJtZr43YRDvbCjk6z3lDE9x8NSi6cf1/GOSxrB472ICOoBBHfm55/IFSy4A6dHprC5efVzbEkKI7hCRPXQAm9mIxXQ4/GtnD8btC1DT4OXLXWU0elsfodJZY5PGUu+tZ0/1niNu9wf8NPgamhN6mj2NMldZmyNiROS6b9V9/Gfnf8IdhhCdFrEJ/WgTsuLZfM8ZPHTJRHwBzabjrKfPy5qHSZl4e/fblLnKeGjNQ1Q2VuLyBedxaZnQ/dpPRWNELdIkOuHz/M/55tA34Q5DiE6L2JJLa+wWE1MGB8/eXHegihlDEo/5uZKjkpmfM5939rzD/tr9LCtYxiHnIX45/ZfBbTXV0KPTAChxlpBqTz3OVyD6ErffTYOvIdxhCNFp/aaH3iQx2sKQ5GjWH6xCa31cpZdLR15KjbuGZQXLGJM4hk8OfMJzW58DINp0uIYOyEiXfsjr9zZ/IxMiEvS7hA4wJSeB9QeqWPTMGi7851ftToXbnpkZM8mNzWVUwiheOOcFJqdO5sXtLwIQY4kBgiUXgGJncfcEL/oM6aGLSBN5JZedH8Hin8F1H0LC4FabTB2cwBvrC/hyVxkAGwtqmJQd3+VNGZSBZ856BrPBjNVo5akznmJT+Sb21+xnVsYsAOKt8ViNVumh9zNaazwBzxHrxwrR10VeD91sg9pCqG79tHyAmUODtfPLp2djNioWbzr2aQGSo5KJs8YFN200MzVtKhePvBiz0QwEF5iWsej9jyfgAZCSi4gokZfQ40O98qq2E/qwFAdf3nEK931/PPNGpPDB5uJjLrt0hpwt2v94/MGELiUXEUkiL6HHZYEytNtDBxicFDw1/9wJGRRWN/BtfnWPhZRmT6PI2X1zsovwa1qFyuV19WhnQIjuFHkJ3WiG2Kx2e+gtnTo6OJRw9b7KHgspKy
aLYmdxc69ORL6mv6VGyxKDImJEXkKH4MHQDnroTeLtFhKjLRyo6LlaaE5MDhpNYX33LbQhwqvlh7PU0UWkiMyEHj+40z10gMFJdg5UOHssnOyYbADy6/J7bBuid7XslUsdXUSKyEzoCYOhvhi8nftHy02K7tEeuiT0/ueIHroMXRQRIjITetNIl+rOJdDcpGgO1TQc94RdbUm0JWI32TlY2z3T94rwa9lDl5KLiBSRmdCbTijqZB09N9mO1pBf2TP/mEopcmJzpIfejzSNQwcpuYjIEZkJvXks+v5ONR+cFJx3ZX8Pl10kofcfUnIRkajDhK6UelopVaqU2tLG/Uop9Xel1G6l1Cal1JTuD/MojjQwWjud0HOTgjMj9vSB0YL6ApkXvZ+Qg6IiEnWmh/4scFY7958NjAhdbgIeOf6wOmAwQEIu7F8Bvo7HCMfbLcRFmdnfwwndF/BR7JJJuvoDGbYoIlGHCV1rvQxo76ycC4DnddBKIF4pldFdAbbp5F9C0QZ4/Xqo6Xj8d26SvcfHooOMdOkvWib0hk6OphIi3LpjtsVMoGUWKwjd9p1z4ZVSNxHsxZOTk3N8Wx1/CTjL4aNfwY73IW08jDoLpl0PWsO6ZyFxCAyZB7GZDE4KzpHeUwbHBuv6G0o3NM/EKCKXjHIRkahXp8/VWj8OPA4wbdq0458gY9bNMOxU2PUh7FoCy/8CX/8jONdLywNZjnRuM4/knboM/LsNGHPngMly3JtvKS06jZMyT+KFbS9w5ZgribXEduvzi94lJRcRibojoRcC2S1+zwrd1jtSRgYvc28PHiRd/mfweeCUX4O7Dg5+A4XrSNu9kl+YlsGLr0LMIJhybfBx2TMhOgX2LYfk4cHa/DG6bfJt/OD9H/D81ue5dfKt3fYSRe9r6qFHm6Ol5CIiRnck9HeBW5VS/wZmAjVa6/BMPZiQC+f/48jbMiYAsHpbCT9//gvePheG7HkBvrz/cBuLAzz1YIqCk35+ONHHpHdp82OSxnDG4DN4avNT2M12Fo1bhEEdeZjiy/wvGRw7mNy43GN4gaK3eAIeDMpAjCVGeugiYnSY0JVSrwCnAMlKqQLgbsAMoLV+FPgAOAfYDbiA63oq2OORHmejBgc7E6YwZNFlwd575T7YvRSq9sGIM2D9C/D5vcEHmKPhpJ/C6PMgeWRwZE0n3D3nbjSav677K/l1+fxu1u9QSgFQ2VjJTz7/CfNz5vOXU/7SUy9VdAOP34PVaMVussuwRRExOkzoWusrOrhfAz/qtoh6SEacDYCimsbgDdaYYO891IMHYPT3oCYf6stgxV/gs3uDl/gcmHodTL4GHCntbifWEsufT/4zf1v/N57e8jRJtiRumXgLRoORJfuX4NM+1havRWvdnOhF3+P2uzEbzNhNdumhi4gReWuKHqPEaAsWo4HipoTeGqWCyTs+By5/CSr2wIGvYdOr8On/whf3wdgLYNxFkL8KbHEw58fBOdqPeBrFT6b8hDJXGY9teowl+5dw16y7eH/v+ygUVe4qdlfvZkTCiB5+1eJYNffQzXapoYuIMWASulKK9Djb4R56ZyQNC16mXAOlO2Dt07DxFdj8GhhMEPDB7k+DB2CzZx0xckYpxb0n3sv8nPn849t/cPMnN+PTPi4bdRmv7nyVNcVrJKH3YR6/B4vRQpQpilJXabjDEaJTInMul2OUEWdrv4fentTRcM6f4Oc74Oo34Rd58P0n4NAGeO48+NfM78zRblAGTh98Oi+e8yITUydiMVi4cfyNZDoyWVO85vhfkOgxbr+7uYYuJRcRKQZcQi+qPc6vz5ZoGL4A7Ikw4Qfwi11w6bPgqoBnzoFdH8NR87nEWmJ54ownWPz9xaRHpzMtbRprS9bK0mZ9WHMP3RwlJRcRMQZUQk+Pi6Kkxk0g0I2L/lodwZr6wvdB++HlS+EvY+Dt/4bKvc3NzAYz6dHBYZALchZQ7a7m/LfOZ2XRyu6LRXQbTyCY0KWHLi
LJgEroGXE2PP4Ala4eWMw5YwLcvgkufQ4Gz4Vt78Kz57W6CMf8nPk8ecaTWIwWfvb5z2Qt0j6oqeQSZYqiwddAcDCXEH3bgEro6aGhi8dcR++IyQLjLoRLn4HrPwyOdX/0RHji1OC0BJ7DPb2ZGTP512n/QqP51bJf4Q14m+97ZccrvLLjlZ6JUXSKx+/BYrBgN9vxa/8RC14I0VcNqIT+nbHoPSl9PFz7Now6Gwzm4LDHh6dDwdrmJtkx2dw9+242lm3kkQ3BWYc3lG7gvlX38cDqB9hXs6/n4xStcvvdzaNcQBa5EJFhQCX09OaE3ksHuTKnwEWPwg1LYNEHYDDCM2fDF/cHZ4oEzhpyFt8f8X2e3PwkT21+iru/vptUeypWo5W/r/9778QpvqPlmaIgi1yIyDBgxqEDJEdbcVhN5JXU9/7Gc+fCTV/Aez8OnqD05Z+Cvfic2fwqewZbE7byt/V/w2Qw8ff5f2dLxRb+teFfrC9Zz5S0nl8EShyp5SgXAKe35xZHEaK7DKiEbjAoxmbEsvVQTXgCsCfCZS9C2U7Y9J/g2abrnsW+6hFeyZ5BxZx7sWXNID4mg6lpU3kz703uXXUvL579Ih/s+4A91Xuo9dQSY4lh0bhFzaNmRPdrKrkk2ZIAqGrsubn0heguAyqhA4zLjOXfq/PxBzRGQ5jmUkkZBQv+J3jd74XNr2Ne8hvSX70WTDY45TfY59zGb2b8hts/v53TXz+dWk8tUaYoYi2xlDeU4/K6+P3c34cn/gGgqeSSFBVM6OUN5WGOSIiODagaOsC4QXE0eP3sKw9D2aU1RjNMugJu3whXvArDT4Old8P9OZz6xu2cnTyFeGs8j532GKuuXMXSS5dyychLWLx3MRUNFXj93o63IbqsaRx6Uw9dErqIBAMwoQdXEtp6qDbMkRzFFhdcQu+yF+GSZ2Dy1WC28cDGT3l/1v9jTuac5tkZrxpzFZ6Ah18u+yVzXpnDg2sexOv38uTmJ9lVtSvMLyTyaa2bx6HHWmIxG8yUN0pCF33fgCu5DE91YDEZ2FJYwwWTMsMdzncpBSd8P3ipKUQ9eRo8fgpkTgVHKsRlM2Ty1ZyUeRLLC5eTHZPN89ueZ1nBMvbX7md5wXKeO/u5cL+KiNZ0ToDFYEEpRXJUMhUNFWGOSoiODbgeutloYFRaDN8erOajLUWU1/fh+VTiMuHGpXDKb4JDHqvzYf1z8NhJ/G91A4+d/Ffeu/A9Tsk6hYL6Ak7JOoX1pevZXLaZRl8jAR0I9yuISE3riVqMwdkzk6OSpeQiIsKA66FDsOzy7zX5rD1QxaI5udxz/rhwh9S2uEw45VfBC0BDFXzzL1KWP0TKgW9g+o38bcadVJlMRJmiOP210/nd17/jUP0hZg+azV9O+ct3lsET7WuaNM1qtAKQFJXEofpD4QxJiE4ZkP/p187O5cYThzAqLYYN+dXhDqdrohLg1Dvhuo+Co2U++wPGv40n+bUbiX7lSi5NmsTu6t0Mjh3Mpwc/5cE1D7KjcscRq9iL9jXtq2qnprS2UUouImIMyB762EGxjB00lnvf38bzKw/g8QWwmCLssy1nJlz7TnBM+5on4eBKcNdy675lnHvKzxmRdSJ/KFrKi9tf5MXtL5Ibm8sf5v4BgCFxQ4izxoX5BfRdTT30l1YWsmfvTnJHJFPlrsIf8GM0GMMcnRBtG5AJvcmE7Hg8K/axq6SOEzIjNMGljIJzHgxe97gwv3I5Iz9/EHiQ/0kYzAVn/5GDRvjbur9xzYfXAJARncHL575MclRy+OLuw5om4nK5FVUuD9NsyQR0gCp3lewz0acN6IQ+MSuYxDcV1HQ5ode7ffj8AeLtlo4b9xaLHa56HQrXgbMUtfjnTHztJiaOPJO5Y25iqSmA2WTjvtX3ccvSW5iUMokRCSO4ZOQlUmdvoank4vMbcXn8R5xcJAld9GUDOqHnJNqJt5vZVFDNlTNzuvTYu9/ZSn6li//cPLuHoj
tGJgsMDsU0aDIsewh2fkDi1rf4QWwWnP9/xJ10P3etuIuCugLqvfV8sO8DHjr5IUlWIU0lF6/PgMvjb94vMtJF9HUDulumlGJ8ZhwbC7o+t0tRTQMHKvv4hE3xOXD+3+FnO+Dyl8EaAy9ezKnv38lXlV6+HnQR9077NdsqtnHdR9fJYsghhxO6EZfHJ6f/i4gxoBM6wMSseHaV1NHo9XfcuIV6t49KpycyVrIxmmD0uXDT5zD3JxCfg3Kkob74f1zw+m086kug1FnMhW9fyD1f38PH+z+msrESgC3lW3gr760BNaa9aToFv98ULLnI6f8iQgzokgvAsNRo/AFNQVUDw1MdnX5cvduH16+pc/uItZl7MMJuZI6C0//38O8l22Dza0xZ/xzPmww8O3IOH+x9nzfy3sBisDAvax6f53+OX/tZVrCMG8ffSE5sDjGWmPC9hl7QvHi3NtHg8WM324k2R8vQRdHnDfiEnpMYXMAgv9LVpYTudPsAqHJ6IiehHy1tLKTdDZOvZtTzF3Lf2nfwAtuyJ/JG9jjezf+c05InMsbo4O/5n7H04FJsRhs3jL+BWRmziLHEMCx+WLhfRbdrSug6EOyhQ/Bs0WJncTjDEqJDAz6hZyeEEnpV15YYc7qD/+gVTg+Dk6K7Pa5elTQMblkBhzZgrshj4hcPMDF/I79JHknU3rcBOGPWD8nLmcri/M/454Z/8s8N/wRgSuoUxiaNxWQwcd6w8xiZMBKAioYKYiwxzafPR5KmchMBKw1+P4GAZnr6dN7Me5PVRauZkTEjvAEK0YYBn9BTYqxYTQYOVnQ+oWutcXoO99D7BVscDD05eJlwGax7jqgtr8O8O8BZRvbKJ8he+QSnAtuTcqiIz2Kfzc7LzmJ2Ve3C7Xfz7NZnSYlKwaAMlLhKyIjO4M8n/5nxKeOP2FRRfRF/Xvdnrhh9BVPTprKrahe5sbl9JvkvPbCU3JjhbPYHS0sNXj93TLuDdSXr+OWyX/L6+a9H/IigB1Y/QJw1jpsn3hzuUEQ3GvAJXSlFdqK9Sz10l8dP07HQiv6S0FuyxsCcW4MXgEAAsqZDIPitZMzOD6G+mBP3r+MaeyIMmkyNq4C3E1LYG5OI257IcJ/mtbLVXP3h1QyPH06WIwu72c6YxDG8vONlCusL+SL/CyalTGJV8SpyY3O5ZOQlFNQVcP6w87/zIdBbDtUfYkPZBq4ccTObAdA0VBeTnJbFn0/+M1cuvpJfLvslj5/+OCZDZP771Hvq+ffOf2M1Wlk4bmHzQtgi8kXmO7Kb5STaOVjZ+UWAm+rnAJX9MaEfzWAIzs/eZOrC4M+iTfDurVCeR1xMOgv3rAF3bXDVJV8jlxoMvDh4HJsb68l3b6dG+3h/7/vEWGJ4ZPIdPLz/PTaVb+L69BNZUrubh9Y+hFEZeXfPu9w16y4O1h1kR8UOSlwl2M12cmNzGZs0lihTFCn2FJJtyWyp2MLQuKFMSJlAjbuGElcJcZY44m3x7Knew5f5X2IymBgWP4x5WfPaTMKVjZWsLV7Lnuo9AExPmc/rbOB+85MkP7IO5tzGiAV3c9esu7jrq7v4f6v+H3dMuwO72d7Te/87dlbupLKxktmDju0ciBWHVuAL+PAFfHyZ/yVnDTmrmyMU4SIJHchOiGLNvkq01s2LSLSnrkVC7zcll2ORMQH+a9nh330e2PoW7F8Go84lrnIvP1r1GNQE+7oYLRQk5BDjLCRu523MVEbcRjOOvJ38yBpLxagzMJbt5IaoRn674rcYNAwNKNKUmQZ7Ah+Xb+WNvDe+E4YBAxcMv4Cl+z+mznfkuQEKaBpYmmxNCJ0Rqzgpax613jpcPhenZp/KkxsfpaihDIATksaRU7yVj6y/IhYXdTkLiPn6H5C3lAtm/4hdg07h+V2vs6xgGdeNuw6bycZ7e95jRMIIJqVOAoLHFgY5BqG1ZlvlNvKq8kiJSiHVnkpKVApR5ig2l23mi/wvWFa4jCmpU/jd7N
+1ecZuo6+RvKo8vin6hkc2PEKAAI+d/hizMmZ1+c/2Rf4XJFgTsBgtvL/3fc7IPQOF6tR7vzO+Lf0Wp9fJzPSZmI0dDxgoc5VR0VjB6MTR3bL99ri8Lh5Y8wAXDb+o+W/Vn6jOjKNWSp0F/B9gBJ7UWt9/1P2LgAeBwtBND2utn2zvOadNm6bXrl17LDF3uyeX7+Xexdv59n9OJyG64zrupoJqzn/4KwAumZrFQ5dO7OkQI5u7HvJXwr5lULUfLA4YcjKU7QCfG4afCl8/DAVrYdAkasp3sslfx4SEUcTF5UBtIRSuIwCUGo24laLIZKTUZGKkx8OT8XEsibYzs8HNxXV11BkM1JosxPt9nF5fh01rvrFF8Z7Djk1rGgwGVkRFkagV2mDgkAqQ6vPz08oq1kbZON3pYm5DI9sDOdzu/RH33XwZU53L4fP/F4wZ+NZq4W9pmaw3BstQuUYHxYEGGvXh8xly7GnU+9xUeqrb3DUmFKNih7C1di8Lchbg8zaQV7mDOp+TS0ZczOzsU1hTvIbXdrxCtTe4bOKC+NEc8NRS7nfxwLwHyHZk81n+ZygUGY4MRsSPwOVzcaD2AAdqDwDBKYDHJo4lNy6X018/nVOzTyUxKpHntz5PtMmO1Wjm+yMvZV7WPHLjcjEqI9HmaEqcJby0/SXm58xndOJo3t39LhajhfHJ4xkWPwylFHWeOsoaytBa88mBT5oPmMdaYrlh/A1cPury73yT2VC6gUP1hxieMJxbPrmFysZKHjr5IaalT8Ov/STaEqnz1LG5bDMZjgw8fg9Or5NMRyYBHcAX8JEdm93u267eU4/DEhy5dqj+EOnR6fxpzZ94aftLxFvjeeXcV8iKyerw7VvqKsVmshFrie2wbUc8fg8rClcwyDHomD/AlFLrtNbTWr2vo4SulDICu4DTgQJgDXCF1npbizaLgGla61s7G1RfSuhLthbzXy+s491b5zIhK77D9l/vKefKJ1YBcOroVJ5eNL2HIxwgtA6u2BQIgK8BLC1GD1Xng6sClCHYRhmCl4AfveFlDux4h8Ejz0WNPgcaquHgN8H1WqfdAEYLVB+A6oOgA+CuQ5fnoTz1+GvyWVW+mZEZM0iefxcUrIGGKnbUWTl/xWA8mHnxhpmcOCI5GFfxJvC64NC3sGMxW0o30Kh9TPUbadB+ii02vI5UvvCUsMtsJlprJjW6mex2U2UwUmoyUm404jQocj1e5noh2ufmH7njeIIasrw+JrrdeIFPou3oUK95vtPFBfVOcr1ehnp9HDCZWJg5iMpjPDXwb0MvY2TAwB0H3mZEXQUVBsWKqCh0i076IJODan8jLu1DAQkmB5W+w2vxDrVn4Pa7KXRXHvHc30uaxJmDz+Q/xctZfuhrAJIMFuwGCzn2dLJSTuA/e95Bh747JZiiGWSOZXtDMRowGUzcOP5G3t35GoXtLP03Nm44Jw4+ldSoVJRSaK0xGoyMSxrHe7te54Vd/+G8nNNxGMy8sv8DcqPSONBQymmpU1lZtROrycr0pBOorQm+L2498Q8oo5m86jxmZcwiPTqd9SXr+e+lN2MxmPnJtJ8zJG4ISbYkMqIzKHYWE2WOIjkqGafXyZ7qPVQ0VJAancqG0g08vulx7FoxxpbElNzT2V69h88Ll1Hnc3HZqMu4a9Zdx/S3O96EPhu4R2t9Zuj33wBore9r0WYREZzQtx2q5Zy/L+fhKyfzvQmDOmz/ybYSfvj8WhLsZgYnRfP2j+b2QpSixzR9kLTw4eYibnlpPQCPXzOVM8alt/5Ynxu8DcFRQi2fo/pgcJI0dz001e3ddWCyBue0j4qHlDHBlai+/jvs+YyK2kISR5+HGns+GC3s+fqvlFTvYYzBTsLEq2DsBdBYG/wgK95E4/pn+aJsI6VGxWmOIURbHBRU7yfPV4PDYGFw4kiyTbEYawsoqS9ia3w6+701uH0ubqmqwQwQkwEnXAzRyZTvWcp6ZwFFfhde7WebSWHRmu
tcAd60wnarmdtqG0nRipVGP0uj7cQEAox3e0i1p2JqrCG6sY4TGxpRAMrAOouZtVFWiixRuLSPzVYLBWYz59U5Odfp5BO7natq60j3+XgkIY6YQICtjni+NEOqz8evK6poNCgsBivRcTkU+l0YPU4aPfW87bCTZzETaKNUNNfVwMooG36lOL+uns1WKwGl+E9hEbuiE3g6KZEdARfx/gBlRiPlpiOnRk4z2KgOeMjwuIkNBNhks7a6nVxjNAU+J76jwpjZ6CXe52GjzUqxyYQjEGCB08VZThczJ92A+cw/tv6e6kB7Cb0zNfRMIL/F7wXAzFbaXayUmkewN/9TrXX+0Q2UUjcBNwHk5HRtMqyelJ0YPMq/4WA1p4xKxWFtf7c0HRTNTrQPjIOi/V0rCcHtOzzVQdPJRa0yWYOXo8XnBC+dcdo9cNo9JB1187DcE/nOaVtNk4KmjsY24Qec1VAV/KAIbSsOGOcsB2tscKK2kKzQBa2hJj/4IRSVCI6U5jbJJ/6UM1puy10X/EZji+M35buhvgSypoHJSm59GZcXrAl+OGVMCH6bCgSg5mDwG1LxZqjYzdRh85maMycYi9+HPvA11Xs+IX7wPFRsBnML1kBCLsTn8Mv6UshfiT60geXOg4wdehrJV98ejMOedMTrofYQV69/Hl99KVUKqNqH8rpoMJjYYPSRHjuY6XN+wI68d6nze5l+/q/Ru5bgr9yDacZUJu36iL9X7oPRl8OIM6kr3cyry+4hzuNkvN/INzHx7NbVoAP8NPd8EmIy2bRvCQ0eJyV+F0X4SDfaqQh4+FZXsMCSyMS4YSRrA8VlW3H43MzKPQs1+ntoexJF298k2ZaEJSYd6kth0KTOvTe6qDM99EuAs7TWN4Z+vwaY2bI3rpRKAuq11m6l1H8Bl2mtT23veftSDx1g7v2fUVjdQLTFyPJfnUpiO7X0F1ce4K63t3DuhAyW7Sxj8/+e2YuRit7w6pqD/OqN4MHcP140vsuzcYoI5HMHP8DMoWGcgQD43Yd/b0sr3/B6Uns99M5U4AqBlkcfsjh88BMArXWF1rppteUnganHEmg4vf2judxz3licHj9r91e227a5h55gp87tw+3r2sReou87sofua6el6DdM1iOTt8HQcTKHXk3mHelMQl8DjFBKDVFKWYDLgXdbNlBKZbT49Xxge/eF2DtSYqxcPiMHs1Gx/mB1u23r3T6UgsyE4B+7yunthQhFb3J7Dyf0hvZKLkL0IR3W0LXWPqXUrcASgsMWn9Zab1VK/R5Yq7V+F/ixUup8wAdUAot6MOYeYzMbGTsojvUHq9ptV+/24bCYSA6VZSqdHtLjbL0RouglTd+6DApcXZxaWYhw6dSJRVrrD4APjrrtdy2u/wb4TfeGFh5TcuJ5ZfVBvP4AZmPrX2Ccbh/RVlPzmHU5MNr/uH0BDAocVhMut5RcRGQY8AtcHG1yTgKN3gA7iurabON0+4m2Gkl2BBN6hdPdZlsRmdy+AFaTEbvF1P4oFyH6EEnoR5mSEw/Qbtml3u3DYTWRFhsssxTVNPZGaKIXub1+rGYDdqtRSi4iYkhCP0pmfBSpMdZ2E3pTySXGZibGaqKouvMTe4nIEOyhG7BbjHJQVEQMSehHUUoxJSehUz10gIx4G4ekh97veHwBLCYDdrPpiNk1hejLJKG3YsrgePIrGyira7023jKhD4qPoqhGeuj9TXMN3WqkQUouIkJIQm/FlJwEAL5to5feVHIByIiLoqhaeuj9jdvnby65yEFRESkkobfihMy4dk8wCo5yCfXQ42xUOD00Si+uX2mqoUeZTVJDFxFDEnorbGYjYzNiW62je3wBPP4ADmtwZraM+ODZojLSpX9xe5uGLRqb148Voq+ThN6GyTkJbCqoxusPHHF70wGylj10QEa69DNuX4thi9JDFxFCEnobpucm0ugNsGpvJXkldVzyyNesP1hFfSihtzwoCshIl36medii2YTHF8Af6H
hlLyHCTdYUbcOCMakkRVt4+qt9mI2KtQequO6ZNVw/dwhwOKGnSw+9Xzp8pmiwtOby+Iixdbw+phDhJD30NtjMRq6eNZjPdpSyZGsJV8zIJsps5K9Ld2E1GRiSEt3cLinaIj30fsbtDY1ysTYldCm7iL5PeujtuHrWYB75Yg82s4Ffnz2G354zhkPVjeQm27G2WK4qI94mY9H7GbcvEKyhWyShi8ghCb0dKTFW7r3wBGJsJuKigl+3R6V/92t3RlwUByqcvR2e6EFuXwCL0UiMNfj3rmmQOe9F3ycJvQM/mJ7dYZvBiXaW55URCGgMhr6zeok4dk2jXJomYCutlZKa6Pukht4Nhqc6aPQGKJQDo/1CIKDx+jVWk4G02OAC0CVtTAMhRF8iCb0bDE91ALC7tD7MkYju4Amde2A1GUlyWDEoKJMeuogAktC7QVNCzytte1EMETma1hO1mgwYDYpkh5WSWumhi75PEno3iLdbSHZYpIfeTzStJ2o1B/890mJtlNRJD130fZLQu8mwFIck9H7C7TtccgFIjbFSKj10EQEkoXeT4anBhK61nCIe6Zp76Kbgv0dqrI1S6aGLCCAJvZsMT3VQ2+ijrL79nty6A5UUy1mlfVpjixo6QFqslfJ6z3cmahOir5GE3k06M9Kl0evnqidXce/ibb0VljgGzSUXc1PJJTgWvbyDD2shwk0SejcZlR4DwMdbS9pss2pfJY3eAMvzyvFJb6/Paiq5WIyHe+iAjHQRfZ4k9G6SGmPj6lk5PPfN/jYXmP5yZxkQPI18Y0F1L0YnuuJwD/3wKBeAEhmLLvo4Sejd6FdnjSYj1saPXlrPy6sOsqO4lhrX4TlAluWVMTE7HoM6nNxF3+M+qoaeGhPsoZfK2aKij5O5XLpRjM3MP6+awv+8s4XfvrUZAKNBMX9UKqeNSWV3aT13nTsGk0Hxxa4yfnbGqFafZ+uhGrQOrm0qel/LM0WB5rNF+9N8LnWNXowGhd0iKaA/kb9mN5uck8B7t57IxoIaCqsa2FRYzZvrC1m6PVhbP3lkCg0eP3/+ZBc/emk9/z1/GOMGHU7cXn+AHz63lkZfgC/vOKXVRRVqGrzE2kwoJROB9YRGz5HDFo0GRUqMtV+VXK55ajXJDgtPLpwe7lBEN5KE3gOUUkzKjmdSdjznTsjgjjNG8eWuMopqGhme6uC6E4dQ4fTwzoZCvtpTzpu3zGFoSnCUzIdbipsXy3h82V5+fsYotNYs2VrMtkO1fL2ngrUHqvjtOaP54UlD+XhbCROy4kiLsfHOxkKm5iSSk2RvM7ZGrx+DUlhMUm1ryzd7K4i1mZpr5wBDkx18uKWYCydlMmd4chijO34FVS425FdjNCgq6t0kOazhDkl0E0novcBkNLBgTFrz7w6riXvOH8f1c4dw0b++YtEza3hq4TSGpzp4cvlehiZHM2ZQLE8s38uMIYks3lTEv9fkY1AwMi2GiVlx/PnjXRRWNfDcNwdIjLZwQmYcy3aVkZ0YxVv/PZeVeyuItpqYNSSJnSV1JEVbqHf7WPTMatJibfznv2ZjNRnQmrBM+Vvj8hIb1fe+ZTR6/Xy8tZjvTRh0xIfeg5dO4Ppn17DwmdU8d92MiE7qn24vBcAf0CzZWsKVM3PCHFHvavT6sZmNHTeMQKozZzYqpc4C/g8wAk9qre8/6n4r8DwwFagALtNa72/vOadNm6bXrl17jGH3H98erOL6Z9fg9PgZmhzNjuI67r3wBE4emcIPHvuGolBv/Ufzh/HjBSOwmoyU1jay4C9fUtfo44yxaewrd7KnrJ5Fc4bw4soDGAyHT45pyWxUxNrMVDg9nDg8mX3lTpweH2efkE5Wgp3cpGhmD0ti9b5KdpfWYTEZqHB62FPqZNuhGrIT7cwbmcKJw5OJjTJzsNLF17vLMRkVo9JjGZHqYHtRLWv2VzIsxcGk7HjGDYojymKkrM7N1kM1JEZbePvbQzzz9T5OHZXKtXNy2VdWj8
cfaP7mcMbY9Oa1Wpt4/QGMSvX4h8+Hm4u45aX1vHjDTE4ccWTSrmnwcvEjX1Pl9PD+j08kIy7quLfX6PXz5493kpVg59rZgzv8gPMHNM98tY/C6gZ+ddboY0pM1zy1isKq4FTP6XE2Xv7hrGOKvTWNXj9//GA7e8rqyYq384szR5ES0/43gM93llJQ6eLqWR2//uP17cEqrnxiFaeNTeP+748n2tr7fdqmRcejLMf2oaKUWqe1ntbqfR0ldKWUEdgFnA4UAGuAK7TW21q0+W9ggtb6ZqXU5cBFWuvL2nteSeiHldW5uf/DHZTWNTJ7WBI/PGkoZqOBRq+fN9cXYjKo7yy08en2EpbtKuPOc8cS0JqS2kYGJ0XzzoZCnly+j5tPHobZqNhYUM3YjDgKq13sK3dy66kjeH1tAX9duotxg2IZmuLgs+0lONtYYs1iNJCdGMXYQXHsLatn66Ha79zv1xp/4PD7yGE1Ue/2Nf9uMxu+8wFz+tg0lueVtfrBE2szcftpI4mxmThU3cCmghq+2l2O1WRg3KA4HDYTmfFRZMTZ2FhQjclgYEJWHNUuL06PD4Vid1k9TrePBLuFBq+PaIuJU0enolRwWGJOop195U6KQ2Uwl8dPeb2b5XnlFFQ1sOq3CzC28uGxu7SeCx5egVKKyTnxTBuciNPj4/MdpQxPdTBjSCK5SdFYTQYavH4KqhpYnlfOrpI6oq0mBsXZyE2OJjfJjlKKN9cXsP5gNQAXTBrEGWPTSXJYsFuCC1RXubzsK3NiNCiqXB4+2FzU3H7q4AQunZpFlMXYfADXH9AEtCbJYSHBbuFQdQP5lS5cXj9njE0j3m5h9n2fcv3cIVhMBv75+W5uOHEIWQl2MuOjyEyIIjbKjN+vSYmxsq/cyTsbCxmVFsOwFAdf76nAbFQMCu1/szH4LS8lxsqhmgb+9NEOVu6tZGJ2PDuLa0mJsfJ/l09mYlY8RoNCa01JrZv/rM1vXrbxldX5AHx/SiZ1jT6Kaxr56ekjiLaY2FPmxGIykBFnIyfRTrzdTFmdm4OVLty+AMkOC5nxdvJK63C6/SQ7gt9W690+Vu6tYEJmPKmxVvJK6nF5fPz439/i82uqXB5yk6K589wxnDIqNThFcr2b9QeqKalt5LSxaWTGR3GouoG3NxRiNRk5cXgyFpOBaIuRxGgLDV4/ZqPhiA/VRq+fktpGNhXUsHpfJXOGJXHmuHQCWrO/wsXS7SU889U+rp2dy4/mD2/1f64jx5vQZwP3aK3PDP3+GwCt9X0t2iwJtflGKWUCioEU3c6TS0IPH601u0rqGZHqaO7xNnr9bC6s4Zs9FUzMjmfmkEQ8/gAOi+mIXnF5vZuVeyvw+gOkOGxMHZyAwQB7y5zsKqljUHwU0wYnUFbvZlN+DduLaqkLJdbJOfFUOT1kxEcxKTuewuoGdpfWMyYjhmiLCV9AU1zTyC/f2MTG/GoAlAquCHXyyBQ8fs2O4loaPH4OVrpwefxkxkfh8Qcoq3NjNCiizEb8Ac2Q5GjiosxUuTzYLUZKat2tLkBiNiq8/iPfpv918lB+c/aYNvffhvxqXlubz7oDVewsqcNkUMwYksjeMmfzN6qWMuOjmJwTT2Mowe+vcDZ/kDmsJv50yQR2Ftfxj8/yCHTwhTkrIYrbF4zAbjHx89c2tPqB2Bmv3zyb1BgbP3x+LfsqnHh8bT+PUtDZKYrMRsWDl0zkwsmZbMiv5vpn11Dp9DSXr1puJy7KTE2DlytmZBMXZeHRL/cQazOREG3hQIXrmF4XgMVkIBDQ+EI702hQzR2OKLORN/97DlVOD3e+vYV95U6UArPB0Dy6qek120xGGrwdryWb7LBiNRlwenxUtxim3PTeclhNuDy+5r/t3OFJ3Dp/BLOHJR3T6zvehH4JcJbW+sbQ79cAM7XWt7ZosyXUpiD0+55Qm/Kjnusm4CaAnJycqQcOHDimFyT6N39Ak1/pah5d0lpZwR/Q1DR4SYy2oL
WmvN5Dgt2Mydj6wV6tNXvK6rGZjZiNBvaVO8lOtJMWY+VApQuH1USKw0q9x0eMtfO1/aa1RuOizGitKatzk1/lwuvX2MxG0mKtpMfajni+QEBTWufGoIJTLzclu9pGLwWVDVS7PLg8fpweHzE2E8NTYtBo7BbTEeULVyiBNHr9NHj9KBRGgwoOsaxzU+XykBEXRXZisDT0waYiGrwBRqU7mD8qtTmmQEBT7nRTWNVAYXUDTrcPg1KU1DYSZTFx8ZRM8krrOVTdwNzhyZgMikPVjRTVNBDQENDB15PisDBlcELzVAkAFaFvPduLalFKYTUZiLYaOWNsOoOT7DR6A82lh3UHKhmeGoPNbODdDYeIizIzLjMOnz9AYVUDBytdzX/zIcnR2MxGimsaKaxuYHiqg7goM8U1jazcW4HZZODU0alszK+m2uVl3KBYbGYjQ1OiGZwUDQTLeO9vOsS+Miduf4CMWBsnZMaRGG3hwy3F1DR4SbBb+N6EjFB8VWg09W4/lfXBjkKj109hdQNevybKYiA91kZarI2hKQ7GZ8bx4ZYi1u6vIt5uJjcpmonZcQxPjenUe6stfSahtyQ9dCGE6Lr2Enpnxq4VAi0LuFmh21ptEyq5xBE8OCqEEKKXdCahrwFGKKWGKKUswOXAu0e1eRdYGLp+CfBZe/VzIYQQ3a/DMTtaa59S6lZgCcFhi09rrbcqpX4PrNVavws8BbyglNoNVBJM+kIIIXpRpwZhaq0/AD446rbftbjeCFzavaEJIYToCjn/Wwgh+glJ6EII0U9IQhdCiH5CEroQQvQTnZqcq0c2rFQZcCyniiYDbZ6wFEYSV9f11dgkrq7pq3FB343teOIarLVOae2OsCX0Y6WUWtvWWVLhJHF1XV+NTeLqmr4aF/Td2HoqLim5CCFEPyEJXQgh+olITOiPhzuANkhcXddXY5O4uqavxgV9N7YeiSviauhCCCFaF4k9dCGEEK2QhC6EEP1ExCR0pdRZSqmdSqndSqlf9/K2s5VSnyultimltiqlbg/dfo9SqlAptSF0OafFY34TinWnUurMHo5vv1JqcyiGtaHbEpVSnyil8kI/E0K3K6XU30OxbVJKTemhmEa12C8blFK1SqmfhGOfKaWeVkqVhhZiabqty/tHKbUw1D5PKbWwtW11U2wPKqV2hLb/llIqPnR7rlKqocW+e7TFY6aG3gO7Q/Ef12rLbcTV5b9dd//fthHXqy1i2q+U2hC6vTf3V1s5onffZ1rrPn8hOG3vHmAoYAE2AmN7cfsZwJTQ9RiCi2aPBe4BftFK+7GhGK3AkFDsxh6Mbz+QfNRtfwJ+Hbr+a+CB0PVzgA8BBcwCVvXS368YGByOfQbMA6YAW451/wCJwN7Qz4TQ9YQeiu0MwBS6/kCL2HJbtjvqeVaH4lWh+M/ugbi69Lfrif/b1uI66v4/A78Lw/5qK0f06vssUnroM4DdWuu9WmsP8G/ggt7auNa6SGu9PnS9DtgOZLbzkAuAf2ut3VrrfcBugq+hN10APBe6/hxwYYvbn9dBK4F4pVRGD8eyANijtW7vzOAe22da62UE5+k/entd2T9nAp9orSu11lXAJ8BZPRGb1vpjrbUv9OtKgquEtSkUX6zWeqUOZoXnW7yebourHW397br9/7a9uEK97B8Ar7T3HD20v9rKEb36PouUhJ4J5Lf4vYD2E2qPUUrlApOBVaGbbg19ZXq66esUvR+vBj5WSq1TwYW4AdK01kWh68VAWphig+CCJy3/yfrCPuvq/gnXe/B6gj25JkOUUt8qpb5USp0Uui0zFE9vxNaVv11v77OTgBKtdV6L23p9fx2VI3r1fRYpCb1PUEo5gDeAn2ita4FHgGHAJKCI4Ne9cDhRaz0FOBv4kVJqXss7Q72QsIxPVcFlC88HXgvd1Ff2WbNw7p/2KKXuBHzAS6GbioAcrfVk4GfAy0qp2F4Mqc/97Y5yBUd2HHp9f7WSI5r1xvssUhJ6Zx
aq7lFKKTPBP9RLWus3AbTWJVprv9Y6ADzB4RJBr8artS4M/SwF3grFUdJUSgn9LA1HbAQ/ZNZrrUtCMfaJfUbX90+vxqeUWgR8D7gqlAgIlTQqQtfXEaxPjwzF0bIs0yOxHcPfrtf2mQouTv994NUW8fbq/motR9DL77NISeidWai6x4Rqc08B27XWf2lxe8va80VA05H3d4HLlVJWpdQQYATBgzA9EVu0Uiqm6TrBA2pbOHLh7oXAOy1iuzZ0lH0WUNPiK2FPOKLX1Bf2WYvtdWX/LAHOUEolhEoNZ4Ru63ZKqbOAXwLna61dLW5PUUoZQ9eHEtxHe0Px1SqlZoXeq9e2eD3dGVdX/3a9+X97GrBDa91cSunN/dVWjqC332fHc2S3Ny8EjwrvIvgpe2cvb/tEgl+VNgEbQpdzgBeAzaHb3wUyWjzmzlCsOznOI+gdxDaU4OiBjcDWpn0DJAGfAnnAUiAxdLsC/hmKbTMwrQdjiwYqgLgWt/X6PiP4gVIEeAnWJG84lv1DsJ69O3S5rgdj202wjtr0Xns01Pbi0N94A7AeOK/F80wjmGD3AA8TOgu8m+Pq8t+uu/9vW4srdPuzwM1Hte3N/dVWjujV95mc+i+EEP1EpJRchBBCdEASuhBC9BOS0IUQop+QhC6EEP2EJHQhhOgnJKGLAU0FZ4C0hzsOIbqDDFsUA5pSaj/BMcDl4Y5FiOMlPXQxYITOql2slNqolNqilLobGAR8rpT6PNTmDKXUN0qp9Uqp10JzczTNOf8nFZxDe7VSang4X4sQrZGELgaSs4BDWuuJWusTgL8Bh4D5Wuv5Sqlk4C7gNB2c7GwtwUmdmtRorccTPLPwb70auRCdIAldDCSbgdOVUg8opU7SWtccdf8sgosSfKWCq94sJLgoR5NXWvyc3dPBCtFVpnAHIERv0VrvUsGlvs4B7lVKfXpUE0VwcYEr2nqKNq4L0SdID10MGEqpQYBLa/0i8CDBpczqCC4ZBsHVgeY21cdDNfeRLZ7ishY/v+mdqIXoPOmhi4FkPPCgUipAcLa+WwiWTj5SSh0K1dEXAa8opayhx9xFcLZAgASl1CbATXBaYCH6FBm2KEQnyPBGEQmk5CKEEP2E9NCFEKKfkB66EEL0E5LQhRCin5CELoQQ/YQkdCGE6CckoQshRD/x/wF1R+oUEOINSAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -1335,7 +1355,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 15, "id": "2ea981cd", "metadata": {}, "outputs": [ @@ -1345,7 +1365,7 @@ "array([3, 4, 5, 6, 7, 8, 9])" ] }, - "execution_count": 45, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1357,542 +1377,187 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 16, "id": "bbd33233", "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2023-03-27T18:05:59.734678+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:05:59.737612+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:00.157952+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:00.160095+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:00.484737+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:00.485757+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:00.786487+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:00.788466+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:01.100020+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:06:01.102261+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:01.460078+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:01.462163+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:01.805568+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:01.807568+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:02.183897+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:02.185904+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:02.569835+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:02.571874+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:03.033272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:03.035146+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:03.579187+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:03.582312+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:04.128201+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:04.131216+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:04.909594+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:04.912681+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:05.506491+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:05.509890+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:06.285555+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:06.287092+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:06.748144+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:06.751143+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:07.239364+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:07.241364+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. 
Original dtype float64.\n", - "[2023-03-27T18:06:07.833861+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:07.835862+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:08.270020+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:08.273103+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:08.579762+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:08.581664+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:08.995746+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:08.996750+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:09.387130+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:09.389133+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:09.913255+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:09.915271+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:10.414403+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:10.417511+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:10.986099+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:10.988092+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:11.384006+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:11.699391+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:11.700392+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:12.138923+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:12.140923+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:12.604077+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:12.606674+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:12.997333+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:12.999663+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", - "[2023-03-27T18:06:13.547570+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:13.550541+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:13.954516+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:13.956516+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:14.445116+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:14.452112+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:14.829066+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:14.832071+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:15.312829+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:15.315831+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:15.757355+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:15.759926+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:16.143136+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:16.145134+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:16.560027+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:16.562025+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:16.861918+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:16.863918+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:17.183558+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:17.637582+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:17.640281+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:17.997687+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:17.998689+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:18.345381+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:18.347383+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. 
Original dtype float64.\n", - "[2023-03-27T18:06:18.676026+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:18.678607+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:19.007549+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:19.010506+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:19.346424+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:19.348531+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:19.696186+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:19.697186+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:20.073809+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:20.077249+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:06:20.472414+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:20.475399+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:20.942265+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:20.944268+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:21.302342+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:21.304314+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:21.665401+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:21.666987+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:22.067854+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:22.069371+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:22.392718+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:22.395677+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:22.716515+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:22.717515+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:23.047434+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:06:23.049434+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:23.399152+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:23.401151+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:23.745625+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:23.747624+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:24.098540+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:24.099540+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:24.421854+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:24.422839+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:24.738758+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:24.739667+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:25.058648+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:25.060550+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:25.399681+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:25.401599+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:25.737806+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:25.738793+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:26.069784+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:26.071290+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:26.416549+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:26.418554+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:26.801542+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:26.803529+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:27.139240+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:27.141225+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. 
Original dtype float64.\n", - "[2023-03-27T18:06:27.488070+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:27.490052+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:27.823788+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:27.824814+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:28.163857+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:28.166838+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:28.499341+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:28.501342+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:28.823408+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:28.825499+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:29.125222+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:29.128129+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:29.492914+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:29.496428+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:29.833079+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:29.835167+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:30.217776+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:30.219777+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:30.536676+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:30.538667+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:30.861816+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:30.863812+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:31.177127+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:31.180126+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:31.606751+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:06:31.607978+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:31.949768+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:31.951785+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:32.289786+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:32.291671+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:32.629730+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 0. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:32.942556+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:32.945560+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:32.949202+0200][38480][INFO] [residual sugar] quality loss for constraints ge = 0.6. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:33.286281+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:33.287280+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:33.620445+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:33.622445+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:33.945427+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:33.947494+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:34.298877+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:34.300955+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:34.618880+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:34.620789+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:34.959467+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:34.961383+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:35.296247+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:35.298303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:35.763113+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:35.765112+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", - "[2023-03-27T18:06:36.178981+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:36.181338+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:36.555008+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:36.555991+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:36.880093+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:36.881104+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:37.299044+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:37.301205+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:37.708557+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:37.711544+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:38.087165+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:38.089166+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:38.482563+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:38.483562+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:38.941184+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:38.942166+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:39.291995+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:39.294883+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:39.642425+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:39.645485+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:39.965926+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:39.967445+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:40.280863+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:40.281866+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:40.567363+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:06:40.569362+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:40.863820+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:40.865893+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:41.406311+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:41.409435+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:42.003319+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:42.006307+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:42.470804+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:42.471786+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:42.768361+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:42.770360+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:43.102405+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:43.105718+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:06:43.426329+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:43.429478+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:43.757004+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:43.759124+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:44.083407+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:44.084408+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:44.400443+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:44.401428+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:44.706402+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:44.708999+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:45.018534+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:45.019535+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", - "[2023-03-27T18:06:45.519397+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:45.521407+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:45.921477+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:45.922985+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:46.265432+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:46.267956+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:46.717722+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:46.719733+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:47.062693+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:47.064691+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:47.417125+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:47.418108+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:47.758309+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:47.760595+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:48.135817+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:48.137801+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:48.458595+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:48.460608+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:48.754069+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:48.756024+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:49.049862+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:49.051462+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:49.350537+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:49.352536+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:49.766318+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:06:49.769390+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:50.276306+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:50.279351+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:50.665664+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:50.666685+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:51.009462+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:51.012707+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:51.308313+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:51.309313+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:51.637138+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:51.639120+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:51.979944+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:51.980946+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:52.297063+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:52.298062+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:52.625280+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:52.628303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:52.938341+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:52.939345+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:53.233624+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:53.235624+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:53.550284+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:53.552284+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:53.859100+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:53.863229+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. 
Original dtype float64.\n", - "[2023-03-27T18:06:54.227895+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:54.229895+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:54.534473+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:54.536457+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:06:54.835486+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:54.837487+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:55.132594+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:55.134593+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:55.465635+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:55.467185+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:55.807745+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:55.810517+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:56.300336+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:56.302923+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:56.604424+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:56.605423+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:56.898530+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:56.900544+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:57.205520+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:57.206520+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:57.503438+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:57.505437+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:57.819558+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:57.821581+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:06:58.160813+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:06:58.163417+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:58.462315+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:58.463303+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:06:58.815614+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:58.817596+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:06:59.129940+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:59.130934+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:06:59.577632+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:59.580621+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:06:59.909210+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:06:59.910211+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:00.263906+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:00.265906+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:00.573175+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:00.574177+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:00.866210+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:00.868793+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:01.205344+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:01.207327+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:01.606906+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:01.608906+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:02.102300+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:02.105211+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:07:02.503969+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:02.506485+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", - "[2023-03-27T18:07:02.906864+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:02.908864+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:03.298141+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:03.300142+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:03.619687+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:03.621670+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:03.942307+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:03.946964+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:04.383317+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:04.384318+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:04.685032+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:04.687584+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:04.985829+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:04.986829+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:07:05.266858+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:05.269157+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:05.580166+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:05.582149+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:05.889785+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:05.892186+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:06.211209+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:06.213722+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:06.513714+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:06.515729+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:06.832167+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:07:06.834177+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:07.144798+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:07.146797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:07.479304+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:07.481835+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:07.846999+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:07.848997+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:08.195789+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:08.197813+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:08.200813+0200][38480][INFO] [residual sugar] quality loss for constraints le = 65.8. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:07:08.691113+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:08.694249+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. 
Original dtype float64.\n", - "[2023-03-27T18:07:09.231893+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:09.235438+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:09.713446+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:09.716162+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:07:10.805837+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:10.809012+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:11.446846+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:11.450600+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:12.110297+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:12.114136+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:12.587219+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:12.589217+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:13.186604+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:13.188628+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:13.765722+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:13.767730+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:14.222493+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:14.225273+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:14.581621+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:14.582622+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:14.916005+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:14.917005+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:15.232768+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:15.233771+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:15.587426+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:07:15.589426+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:07:15.937914+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:15.939914+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:16.341209+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:16.343228+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:16.667291+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:16.669292+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:16.989838+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:16.991912+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:17.306825+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:17.308797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:17.659105+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:17.661131+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:18.018946+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:18.019947+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:18.393086+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:18.396311+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:07:18.830421+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:18.833527+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:19.232926+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:19.236012+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:19.669845+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:19.672139+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:20.034654+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:20.035654+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. 
Original dtype float64.\n", - "[2023-03-27T18:07:20.365288+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:20.367291+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:20.677852+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:20.680692+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:20.988636+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:20.990732+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:21.326922+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:21.329905+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:07:21.682149+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:21.684150+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:22.042272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 1. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:22.043272+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:07:22.417916+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:22.418916+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:22.749237+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:22.751237+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:23.090475+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:23.091459+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:23.470508+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:23.473305+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:23.821072+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:23.823567+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 1. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:23.827191+0200][38480][INFO] [residual sugar] quality loss for constraints ge = 0.6. Remaining 0. prev length 1. Original dtype float64.\n", - "[2023-03-27T18:07:24.193607+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:24.194590+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. 
Original dtype float64.\n", - "[2023-03-27T18:07:24.532529+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:24.534525+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:24.876586+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:24.878585+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:25.216076+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:25.217076+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:25.599528+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:25.601333+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:26.159795+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:26.161982+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:26.541276+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:26.542274+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:26.869887+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. 
Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:26.872038+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:27.183814+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:27.186139+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:27.522592+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:27.524574+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:27.885528+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:27.886547+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 5. Original dtype float64.\n", - "[2023-03-27T18:07:28.236311+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 6. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:28.237310+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 6. Original dtype float64.\n", - "[2023-03-27T18:07:28.569622+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:28.571622+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:28.889372+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. 
Original dtype float64.\n", - "[2023-03-27T18:07:28.890372+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:29.200272+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:29.202272+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:29.533137+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 4. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:29.535216+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 4. Original dtype float64.\n", - "[2023-03-27T18:07:29.936280+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:29.939026+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:30.369796+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 2. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:30.371797+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 2. Original dtype float64.\n", - "[2023-03-27T18:07:30.718054+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 3. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:30.720128+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. Remaining 0. prev length 3. Original dtype float64.\n", - "[2023-03-27T18:07:31.139806+0200][38480][INFO] [alcohol] quality loss for constraints le = 14.2. Remaining 5. prev length 7. Original dtype float64.\n", - "[2023-03-27T18:07:31.140809+0200][38480][INFO] [alcohol] quality loss for constraints ge = 8.0. 
Remaining 0. prev length 5. Original dtype float64.\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mplugin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutcome\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, count, constraints, random_state, **kwargs)\u001b[0m\n\u001b[0;32m 337\u001b[0m \u001b[0msyn_schema\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mSchema\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_constraints\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgen_constraints\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 338\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 339\u001b[1;33m \u001b[0mX_syn\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_generate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msyn_schema\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msyn_schema\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 340\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 341\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mX_syn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_tabular\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36m_generate\u001b[1;34m(self, count, syn_schema, **kwargs)\u001b[0m\n\u001b[0;32m 246\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 247\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 248\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_safe_generate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msyn_schema\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 249\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 250\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in 
\u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\plugin.py\u001b[0m in \u001b[0;36m_safe_generate\u001b[1;34m(self, gen_cbk, count, syn_schema, **kwargs)\u001b[0m\n\u001b[0;32m 391\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msampling_patience\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 392\u001b[0m \u001b[1;31m# sample\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 393\u001b[1;33m \u001b[0miter_samples\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgen_cbk\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 394\u001b[0m iter_samples_df = pd.DataFrame(\n\u001b[0;32m 395\u001b[0m \u001b[0miter_samples\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining_schema\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_ddpm.py\u001b[0m in \u001b[0;36mcallback\u001b[1;34m(count)\u001b[0m\n\u001b[0;32m 241\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 242\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0mcallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# type: ignore\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 243\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 244\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_classification\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_iloc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\__init__.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, count, cond)\u001b[0m\n\u001b[0;32m 211\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 212\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 213\u001b[1;33m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiffusion\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 214\u001b[0m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_col_perm\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\gaussian_multinomial_diffsuion.py\u001b[0m in \u001b[0;36msample_all\u001b[1;34m(self, num_samples, cond, max_batch_size, ddim)\u001b[0m\n\u001b[0;32m 951\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 952\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mb\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mbs\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 953\u001b[1;33m \u001b[0msample\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0msample_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 954\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0many\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msample\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 955\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"found NaNs in sample\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\autograd\\grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 27\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 28\u001b[0m \u001b[1;32mreturn\u001b[0m 
\u001b[0mcast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mF\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\gaussian_multinomial_diffsuion.py\u001b[0m in \u001b[0;36msample\u001b[1;34m(self, num_samples, cond)\u001b[0m\n\u001b[0;32m 918\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"Sample timestep {i:4d}\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"\\r\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 919\u001b[0m \u001b[0mt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfull\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 920\u001b[1;33m model_out = self.denoise_fn(\n\u001b[0m\u001b[0;32m 921\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mz_norm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlog_z\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0my\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcond\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 922\u001b[0m )\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\tabular_ddpm\\modules.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, t, y)\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0memb\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0memb_nonlin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlabel_emb\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0memb\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 113\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - 
"\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\mlp.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mvalidate_arguments\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marbitrary_types_allowed\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 399\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 400\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 401\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 402\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_train_epoch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloader\u001b[0m\u001b[1;33m:\u001b[0m 
\u001b[0mDataLoader\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mfloat\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 
204\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.validate_arguments.validate.wrapper_function\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.call\u001b[1;34m()\u001b[0m\n", - 
"\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\pydantic\\decorator.cp39-win_amd64.pyd\u001b[0m in \u001b[0;36mpydantic.decorator.ValidatedFunction.execute\u001b[1;34m()\u001b[0m\n", - "\u001b[1;32md:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\core\\models\\mlp.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mvalidate_arguments\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marbitrary_types_allowed\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not 
(self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 203\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 204\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 205\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 206\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32md:\\DevTools\\Miniconda\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mKeyboardInterrupt\u001b[0m: " - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
03.81.100.0065.80.009000289.050.1049971.0388933.820.2200008.05
114.20.081.660.60.251377289.09.0000000.9872913.821.0800008.06
23.81.100.000.60.0090002.09.0000000.9871103.821.08000014.27
33.80.081.660.60.0090002.09.0000000.9871103.821.08000014.26
43.81.101.660.60.009000289.0440.0000000.9871103.821.08000014.26
53.81.100.000.60.0090002.09.0000000.9871103.821.0799758.06
614.20.081.660.60.009000289.09.0000000.9871103.821.08000014.27
\n", + "
" + ], + "text/plain": [ + " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", + "0 3.8 1.10 0.00 65.8 0.009000 \n", + "1 14.2 0.08 1.66 0.6 0.251377 \n", + "2 3.8 1.10 0.00 0.6 0.009000 \n", + "3 3.8 0.08 1.66 0.6 0.009000 \n", + "4 3.8 1.10 1.66 0.6 0.009000 \n", + "5 3.8 1.10 0.00 0.6 0.009000 \n", + "6 14.2 0.08 1.66 0.6 0.009000 \n", + "\n", + " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", + "0 289.0 50.104997 1.038893 3.82 0.220000 \n", + "1 289.0 9.000000 0.987291 3.82 1.080000 \n", + "2 2.0 9.000000 0.987110 3.82 1.080000 \n", + "3 2.0 9.000000 0.987110 3.82 1.080000 \n", + "4 289.0 440.000000 0.987110 3.82 1.080000 \n", + "5 2.0 9.000000 0.987110 3.82 1.079975 \n", + "6 289.0 9.000000 0.987110 3.82 1.080000 \n", + "\n", + " alcohol quality \n", + "0 8.0 5 \n", + "1 8.0 6 \n", + "2 14.2 7 \n", + "3 14.2 6 \n", + "4 14.2 6 \n", + "5 8.0 6 \n", + "6 14.2 7 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ From 428177b34d48ea816aaba47937d298528caa530d Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 31 Mar 2023 11:58:19 +0200 Subject: [PATCH 34/95] debug LogDistribution and LogIntDistribution --- src/synthcity/plugins/core/distribution.py | 46 +++++++++----------- src/synthcity/plugins/generic/plugin_ddpm.py | 2 +- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index 3aff6f74..96ce24db 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -157,7 +157,7 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - return np.random.choice(self.choices, count).tolist() + return np.random.choice(self.choices, count) def has(self, val: Any) -> bool: return val in self.choices @@ -209,8 +209,8 @@ class FloatDistribution(Distribution): :parts: 1 """ - low: float = 
np.iinfo(np.int64).min - high: float = np.iinfo(np.int64).max + low: float = np.finfo(np.float64).min + high: float = np.finfo(np.float64).max @validator("low", always=True) def _validate_low_thresh(cls: Any, v: float, values: Dict) -> float: @@ -274,18 +274,18 @@ def dtype(self) -> str: class LogDistribution(FloatDistribution): - low: float = np.iinfo(np.int64).min - high: float = np.iinfo(np.int64).max - base: float = 10.0 - _log_low: float = np.log(low) / np.log(base) - _log_high: float = np.log(high) / np.log(base) + low: float = np.finfo(np.float64).tiny + high: float = np.finfo(np.float64).max + base: float = 2.0 def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples - return self.base ** np.random.uniform(self._log_low, self._log_high, count) + lo = np.log2(self.low) / np.log2(self.base) + hi = np.log2(self.high) / np.log2(self.base) + return self.base ** np.random.uniform(lo, hi, count) class IntegerDistribution(Distribution): @@ -322,8 +322,9 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - choices = [val for val in range(self.low, self.high + 1, self.step)] - return np.random.choice(choices, count).tolist() + high = (self.high + 1 - self.low) // self.step + s = np.random.choice(high, count) + return s * self.step + self.low def has(self, val: Any) -> bool: return self.low <= val and val <= self.high @@ -361,18 +362,18 @@ def dtype(self) -> str: class LogIntDistribution(FloatDistribution): - low: int = np.iinfo(np.int64).min - high: int = np.iinfo(np.int64).max - base: float = 10.0 - _log_low: float = np.log(low) / np.log(base) - _log_high: float = np.log(high) / np.log(base) + low: float = 1.0 + high: float = float(np.iinfo(np.int64).max) + base: float = 2.0 def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples - s = 
self.base ** np.random.uniform(self._log_low, self._log_high, count) + lo = np.log2(self.low) / np.log2(self.base) + hi = np.log2(self.high) / np.log2(self.base) + s = self.base ** np.random.uniform(lo, hi, count) return s.astype(int) @@ -411,15 +412,8 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - samples = np.random.uniform( - datetime.timestamp(self.low), datetime.timestamp(self.high), count - ) - - samples_dt = [] - for s in samples: - samples_dt.append(datetime.fromtimestamp(s)) - - return samples_dt + delta = self.high - self.low + return self.low + delta * np.random.rand(count) def has(self, val: datetime) -> bool: return self.low <= val and val <= self.high diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 9ac18878..8eda1ea9 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -189,7 +189,7 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: return [ LogDistribution(name="lr", low=1e-5, high=1e-1), LogIntDistribution(name="batch_size", low=256, high=4096), - IntegerDistribution(name="num_timesteps", choices=[100, 1000]), + IntegerDistribution(name="num_timesteps", low=10, high=1000), LogIntDistribution(name="n_iter", low=1000, high=10000), IntegerDistribution(name="n_layers_hidden", low=2, high=8), LogIntDistribution(name="dim_hidden", low=128, high=1024), From 4705319e082c98e8f6405e479bd8a582e25318bc Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 1 Apr 2023 21:57:01 +0200 Subject: [PATCH 35/95] change discrete encoding of BinEncoder to passthrough; passed all tests in test_tabular_encoder --- .gitignore | 1 + src/synthcity/plugins/core/models/tabular_encoder.py | 12 +++++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index b2bc0daa..c41e0784 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ generated MNIST 
cifar-10* src/test.py +.tmp.py diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index e5f260e3..74b72142 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -83,7 +83,7 @@ def __init__( """ self.whitelist = whitelist self.categorical_limit = categorical_limit - self.max_clusters = max_clusters # for compatibility + self.max_clusters = max_clusters if categorical_encoder is not None: self.categorical_encoder = categorical_encoder if continuous_encoder is not None: @@ -150,10 +150,7 @@ def fit( continue column_hash = dataframe_hash(raw_data[[name]]) log.info(f"Encoding {name} {column_hash}") - if name in discrete_columns: - ftype = "discrete" - else: - ftype = "continuous" + ftype = "discrete" if name in discrete_columns else "continuous" column_transform_info = self._fit_feature(raw_data[name], ftype) self.output_dimensions += column_transform_info.output_dimensions @@ -289,8 +286,9 @@ class BinEncoder(TabularEncoder): continuous_encoder = "bayesian_gmm" cont_encoder_params = dict(n_components=2) - categorical_encoder = "onehot" - cat_encoder_params = dict(handle_unknown="ignore", sparse=False) + categorical_encoder = "passthrough" # "onehot" + # ! 
onehot encoder does not pass the tests + cat_encoder_params = dict() # dict(handle_unknown="ignore", sparse=False) # TODO: check if this is correct def _transform_feature( From d9d73f14026d87753e153afe3140c94d7175b8cc Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 2 Apr 2023 20:36:29 +0200 Subject: [PATCH 36/95] add tabnet to plugins/core/models --- src/synthcity/plugins/core/models/tabnet.py | 1030 +++++++++++++++++++ 1 file changed, 1030 insertions(+) create mode 100644 src/synthcity/plugins/core/models/tabnet.py diff --git a/src/synthcity/plugins/core/models/tabnet.py b/src/synthcity/plugins/core/models/tabnet.py new file mode 100644 index 00000000..25383cb9 --- /dev/null +++ b/src/synthcity/plugins/core/models/tabnet.py @@ -0,0 +1,1030 @@ +# third party +import numpy as np +import torch +from torch.autograd import Function +from torch.nn import BatchNorm1d, Linear, ReLU + + +# credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py +def _make_ix_like(input, dim=0): + d = input.size(dim) + rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + view = [1] * input.dim() + view[0] = -1 + return rho.view(view).transpose(0, dim) + + +class SparsemaxFunction(Function): + """ + An implementation of sparsemax (Martins & Astudillo, 2016). See + :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
+ By Ben Peters and Vlad Niculae + """ + + @staticmethod + def forward(ctx, input, dim=-1): + """sparsemax: normalizing sparse transform (a la softmax) + + Parameters + ---------- + ctx : torch.autograd.function._ContextMethodMixin + input : torch.Tensor + any shape + dim : int + dimension along which to apply sparsemax + + Returns + ------- + output : torch.Tensor + same shape as input + + """ + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val # same numerical stability trick as for softmax + tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) + output = torch.clamp(input - tau, min=0) + ctx.save_for_backward(supp_size, output) + return output + + @staticmethod + def backward(ctx, grad_output): + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + return grad_input, None + + @staticmethod + def _threshold_and_support(input, dim=-1): + """Sparsemax building block: compute the threshold + + Parameters + ---------- + input: torch.Tensor + any dimension + dim : int + dimension along which to apply the sparsemax + + Returns + ------- + tau : torch.Tensor + the threshold value + support_size : torch.Tensor + + """ + + input_srt, _ = torch.sort(input, descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + return tau, support_size + + +sparsemax = SparsemaxFunction.apply + + +def initialize_non_glu(module, input_dim, output_dim): + gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(4 * input_dim)) + 
torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + # torch.nn.init.zeros_(module.bias) + return + + +def initialize_glu(module, input_dim, output_dim): + gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) + torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + # torch.nn.init.zeros_(module.bias) + return + + +class GBN(torch.nn.Module): + """ + Ghost Batch Normalization + https://arxiv.org/abs/1705.08741 + """ + + def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01): + super(GBN, self).__init__() + + self.input_dim = input_dim + self.virtual_batch_size = virtual_batch_size + self.bn = BatchNorm1d(self.input_dim, momentum=momentum) + + def forward(self, x): + chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) + res = [self.bn(x_) for x_ in chunks] + + return torch.cat(res, dim=0) + + +class TabNetEncoder(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + n_d=8, + n_a=8, + n_steps=3, + gamma=1.3, + n_independent=2, + n_shared=2, + epsilon=1e-15, + virtual_batch_size=128, + momentum=0.02, + mask_type="sparsemax", + group_attention_matrix=None, + ): + """ + Defines main part of the TabNet network without the embedding layers. + + Parameters + ---------- + input_dim : int + Number of features + output_dim : int or list of int for multi task classification + Dimension of network output + examples : one for regression, 2 for binary classification etc... 
+ n_d : int + Dimension of the prediction layer (usually between 4 and 64) + n_a : int + Dimension of the attention layer (usually between 4 and 64) + n_steps : int + Number of successive steps in the network (usually between 3 and 10) + gamma : float + Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) + n_independent : int + Number of independent GLU layer in each GLU block (default 2) + n_shared : int + Number of independent GLU layer in each GLU block (default 2) + epsilon : float + Avoid log(0), this should be kept very low + virtual_batch_size : int + Batch size for Ghost Batch Normalization + momentum : float + Float value between 0 and 1 which will be used for momentum in all batch norm + mask_type : str + Either "sparsemax" or "entmax" : this is the masking function to use + group_attention_matrix : torch matrix + Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j + """ + super(TabNetEncoder, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.is_multi_task = isinstance(output_dim, list) + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.epsilon = epsilon + self.n_independent = n_independent + self.n_shared = n_shared + self.virtual_batch_size = virtual_batch_size + self.mask_type = mask_type + self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01) + self.group_attention_matrix = group_attention_matrix + + if self.group_attention_matrix is None: + # no groups + self.group_attention_matrix = torch.eye(self.input_dim) + self.attention_dim = self.input_dim + else: + self.attention_dim = self.group_attention_matrix.shape[0] + + if self.n_shared > 0: + shared_feat_transform = torch.nn.ModuleList() + for i in range(self.n_shared): + if i == 0: + shared_feat_transform.append( + Linear(self.input_dim, 2 * (n_d + n_a), bias=False) + ) + else: + shared_feat_transform.append( + Linear(n_d + n_a, 2 * (n_d + n_a), bias=False) + ) + + 
else: + shared_feat_transform = None + + self.initial_splitter = FeatTransformer( + self.input_dim, + n_d + n_a, + shared_feat_transform, + n_glu_independent=self.n_independent, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + ) + + self.feat_transformers = torch.nn.ModuleList() + self.att_transformers = torch.nn.ModuleList() + + for step in range(n_steps): + transformer = FeatTransformer( + self.input_dim, + n_d + n_a, + shared_feat_transform, + n_glu_independent=self.n_independent, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + ) + attention = AttentiveTransformer( + n_a, + self.attention_dim, + group_matrix=group_attention_matrix, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + mask_type=self.mask_type, + ) + self.feat_transformers.append(transformer) + self.att_transformers.append(attention) + + def forward(self, x, prior=None): + x = self.initial_bn(x) + + bs = x.shape[0] # batch size + if prior is None: + prior = torch.ones((bs, self.attention_dim)).to(x.device) + + M_loss = 0 + att = self.initial_splitter(x)[:, self.n_d :] + steps_output = [] + for step in range(self.n_steps): + M = self.att_transformers[step](prior, att) + M_loss += torch.mean( + torch.sum(torch.mul(M, torch.log(M + self.epsilon)), dim=1) + ) + # update prior + prior = torch.mul(self.gamma - M, prior) + # output + M_feature_level = torch.matmul(M, self.group_attention_matrix) + masked_x = torch.mul(M_feature_level, x) + out = self.feat_transformers[step](masked_x) + d = ReLU()(out[:, : self.n_d]) + steps_output.append(d) + # update attention + att = out[:, self.n_d :] + + M_loss /= self.n_steps + return steps_output, M_loss + + def forward_masks(self, x): + x = self.initial_bn(x) + bs = x.shape[0] # batch size + prior = torch.ones((bs, self.attention_dim)).to(x.device) + M_explain = torch.zeros(x.shape).to(x.device) + att = self.initial_splitter(x)[:, self.n_d :] + masks = {} + + for step in range(self.n_steps): + M = 
self.att_transformers[step](prior, att) + M_feature_level = torch.matmul(M, self.group_attention_matrix) + masks[step] = M_feature_level + # update prior + prior = torch.mul(self.gamma - M, prior) + # output + masked_x = torch.mul(M_feature_level, x) + out = self.feat_transformers[step](masked_x) + d = ReLU()(out[:, : self.n_d]) + # explain + step_importance = torch.sum(d, dim=1) + M_explain += torch.mul(M_feature_level, step_importance.unsqueeze(dim=1)) + # update attention + att = out[:, self.n_d :] + + return M_explain, masks + + +class TabNetDecoder(torch.nn.Module): + def __init__( + self, + input_dim, + n_d=8, + n_steps=3, + n_independent=1, + n_shared=1, + virtual_batch_size=128, + momentum=0.02, + ): + """ + Defines main part of the TabNet network without the embedding layers. + + Parameters + ---------- + input_dim : int + Number of features + output_dim : int or list of int for multi task classification + Dimension of network output + examples : one for regression, 2 for binary classification etc... 
+ n_d : int + Dimension of the prediction layer (usually between 4 and 64) + n_steps : int + Number of successive steps in the network (usually between 3 and 10) + gamma : float + Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) + n_independent : int + Number of independent GLU layer in each GLU block (default 1) + n_shared : int + Number of independent GLU layer in each GLU block (default 1) + virtual_batch_size : int + Batch size for Ghost Batch Normalization + momentum : float + Float value between 0 and 1 which will be used for momentum in all batch norm + """ + super(TabNetDecoder, self).__init__() + self.input_dim = input_dim + self.n_d = n_d + self.n_steps = n_steps + self.n_independent = n_independent + self.n_shared = n_shared + self.virtual_batch_size = virtual_batch_size + + self.feat_transformers = torch.nn.ModuleList() + + if self.n_shared > 0: + shared_feat_transform = torch.nn.ModuleList() + for i in range(self.n_shared): + if i == 0: + shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) + else: + shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) + + else: + shared_feat_transform = None + + for step in range(n_steps): + transformer = FeatTransformer( + n_d, + n_d, + shared_feat_transform, + n_glu_independent=self.n_independent, + virtual_batch_size=self.virtual_batch_size, + momentum=momentum, + ) + self.feat_transformers.append(transformer) + + self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False) + initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim) + + def forward(self, steps_output): + res = 0 + for step_nb, step_output in enumerate(steps_output): + x = self.feat_transformers[step_nb](step_output) + res = torch.add(res, x) + res = self.reconstruction_layer(res) + return res + + +class TabNetPretraining(torch.nn.Module): + def __init__( + self, + input_dim, + pretraining_ratio=0.2, + n_d=8, + n_a=8, + n_steps=3, + gamma=1.3, + cat_idxs=[], + cat_dims=[], + 
cat_emb_dim=1, + n_independent=2, + n_shared=2, + epsilon=1e-15, + virtual_batch_size=128, + momentum=0.02, + mask_type="sparsemax", + n_shared_decoder=1, + n_indep_decoder=1, + group_attention_matrix=None, + ): + super(TabNetPretraining, self).__init__() + + self.cat_idxs = cat_idxs or [] + self.cat_dims = cat_dims or [] + self.cat_emb_dim = cat_emb_dim + + self.input_dim = input_dim + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.epsilon = epsilon + self.n_independent = n_independent + self.n_shared = n_shared + self.mask_type = mask_type + self.pretraining_ratio = pretraining_ratio + self.n_shared_decoder = n_shared_decoder + self.n_indep_decoder = n_indep_decoder + + if self.n_steps <= 0: + raise ValueError("n_steps should be a positive integer.") + if self.n_independent == 0 and self.n_shared == 0: + raise ValueError("n_shared and n_independent can't be both zero.") + + self.virtual_batch_size = virtual_batch_size + self.embedder = EmbeddingGenerator( + input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix + ) + self.post_embed_dim = self.embedder.post_embed_dim + + self.masker = RandomObfuscator( + self.pretraining_ratio, group_matrix=self.embedder.embedding_group_matrix + ) + self.encoder = TabNetEncoder( + input_dim=self.post_embed_dim, + output_dim=self.post_embed_dim, + n_d=n_d, + n_a=n_a, + n_steps=n_steps, + gamma=gamma, + n_independent=n_independent, + n_shared=n_shared, + epsilon=epsilon, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + mask_type=mask_type, + group_attention_matrix=self.embedder.embedding_group_matrix, + ) + self.decoder = TabNetDecoder( + self.post_embed_dim, + n_d=n_d, + n_steps=n_steps, + n_independent=self.n_indep_decoder, + n_shared=self.n_shared_decoder, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + ) + + def forward(self, x): + """ + Returns: res, embedded_x, obf_vars + res : output of reconstruction + embedded_x : embedded input + obf_vars : 
which variable where obfuscated + """ + embedded_x = self.embedder(x) + if self.training: + masked_x, obfuscated_groups, obfuscated_vars = self.masker(embedded_x) + # set prior of encoder with obfuscated groups + prior = 1 - obfuscated_groups + steps_out, _ = self.encoder(masked_x, prior=prior) + res = self.decoder(steps_out) + return res, embedded_x, obfuscated_vars + else: + steps_out, _ = self.encoder(embedded_x) + res = self.decoder(steps_out) + return res, embedded_x, torch.ones(embedded_x.shape).to(x.device) + + def forward_masks(self, x): + embedded_x = self.embedder(x) + return self.encoder.forward_masks(embedded_x) + + +class TabNetNoEmbeddings(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + n_d=8, + n_a=8, + n_steps=3, + gamma=1.3, + n_independent=2, + n_shared=2, + epsilon=1e-15, + virtual_batch_size=128, + momentum=0.02, + mask_type="sparsemax", + group_attention_matrix=None, + ): + """ + Defines main part of the TabNet network without the embedding layers. + + Parameters + ---------- + input_dim : int + Number of features + output_dim : int or list of int for multi task classification + Dimension of network output + examples : one for regression, 2 for binary classification etc... 
+ n_d : int + Dimension of the prediction layer (usually between 4 and 64) + n_a : int + Dimension of the attention layer (usually between 4 and 64) + n_steps : int + Number of successive steps in the network (usually between 3 and 10) + gamma : float + Float above 1, scaling factor for attention updates (usually between 1.0 and 2.0) + n_independent : int + Number of independent GLU layers in each GLU block (default 2) + n_shared : int + Number of shared GLU layers in each GLU block (default 2) + epsilon : float + Avoid log(0), this should be kept very low + virtual_batch_size : int + Batch size for Ghost Batch Normalization + momentum : float + Float value between 0 and 1 which will be used for momentum in all batch norm + mask_type : str + Either "sparsemax" or "entmax" : this is the masking function to use + group_attention_matrix : torch matrix + Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j + """ + super(TabNetNoEmbeddings, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.is_multi_task = isinstance(output_dim, list) + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.epsilon = epsilon + self.n_independent = n_independent + self.n_shared = n_shared + self.virtual_batch_size = virtual_batch_size + self.mask_type = mask_type + self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01) + + self.encoder = TabNetEncoder( + input_dim=input_dim, + output_dim=output_dim, + n_d=n_d, + n_a=n_a, + n_steps=n_steps, + gamma=gamma, + n_independent=n_independent, + n_shared=n_shared, + epsilon=epsilon, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + mask_type=mask_type, + group_attention_matrix=group_attention_matrix, + ) + + if self.is_multi_task: + self.multi_task_mappings = torch.nn.ModuleList() + for task_dim in output_dim: + task_mapping = Linear(n_d, task_dim, bias=False) + initialize_non_glu(task_mapping, n_d, task_dim) + 
self.multi_task_mappings.append(task_mapping) + else: + self.final_mapping = Linear(n_d, output_dim, bias=False) + initialize_non_glu(self.final_mapping, n_d, output_dim) + + def forward(self, x): + res = 0 + steps_output, M_loss = self.encoder(x) + res = torch.sum(torch.stack(steps_output, dim=0), dim=0) + + if self.is_multi_task: + # Result will be in list format + out = [] + for task_mapping in self.multi_task_mappings: + out.append(task_mapping(res)) + else: + out = self.final_mapping(res) + return out, M_loss + + def forward_masks(self, x): + return self.encoder.forward_masks(x) + + +class TabNet(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + n_d=8, + n_a=8, + n_steps=3, + gamma=1.3, + cat_idxs=[], + cat_dims=[], + cat_emb_dim=1, + n_independent=2, + n_shared=2, + epsilon=1e-15, + virtual_batch_size=128, + momentum=0.02, + mask_type="sparsemax", + group_attention_matrix=[], + ): + """ + Defines TabNet network + + Parameters + ---------- + input_dim : int + Initial number of features + output_dim : int + Dimension of network output + examples : one for regression, 2 for binary classification etc... 
+ n_d : int + Dimension of the prediction layer (usually between 4 and 64) + n_a : int + Dimension of the attention layer (usually between 4 and 64) + n_steps : int + Number of successive steps in the network (usually between 3 and 10) + gamma : float + Float above 1, scaling factor for attention updates (usually between 1.0 and 2.0) + cat_idxs : list of int + Index of each categorical column in the dataset + cat_dims : list of int + Number of categories in each categorical column + cat_emb_dim : int or list of int + Size of the embedding of categorical features + if int, all categorical features will have same embedding size + if list of int, every corresponding feature will have specific size + n_independent : int + Number of independent GLU layers in each GLU block (default 2) + n_shared : int + Number of shared GLU layers in each GLU block (default 2) + epsilon : float + Avoid log(0), this should be kept very low + virtual_batch_size : int + Batch size for Ghost Batch Normalization + momentum : float + Float value between 0 and 1 which will be used for momentum in all batch norm + mask_type : str + Either "sparsemax" or "entmax" : this is the masking function to use + group_attention_matrix : torch matrix + Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j + """ + super(TabNet, self).__init__() + self.cat_idxs = cat_idxs or [] + self.cat_dims = cat_dims or [] + self.cat_emb_dim = cat_emb_dim + + self.input_dim = input_dim + self.output_dim = output_dim + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.epsilon = epsilon + self.n_independent = n_independent + self.n_shared = n_shared + self.mask_type = mask_type + + if self.n_steps <= 0: + raise ValueError("n_steps should be a positive integer.") + if self.n_independent == 0 and self.n_shared == 0: + raise ValueError("n_shared and n_independent can't be both zero.") + + self.virtual_batch_size = virtual_batch_size + self.embedder = 
EmbeddingGenerator( + input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix + ) + self.post_embed_dim = self.embedder.post_embed_dim + + self.tabnet = TabNetNoEmbeddings( + self.post_embed_dim, + output_dim, + n_d, + n_a, + n_steps, + gamma, + n_independent, + n_shared, + epsilon, + virtual_batch_size, + momentum, + mask_type, + self.embedder.embedding_group_matrix, + ) + + def forward(self, x): + x = self.embedder(x) + return self.tabnet(x) + + def forward_masks(self, x): + x = self.embedder(x) + return self.tabnet.forward_masks(x) + + +class AttentiveTransformer(torch.nn.Module): + def __init__( + self, + input_dim, + group_dim, + group_matrix, + virtual_batch_size=128, + momentum=0.02, + mask_type="sparsemax", + ): + """ + Initialize an attention transformer. + + Parameters + ---------- + input_dim : int + Input size + group_dim : int + Number of groups for features + virtual_batch_size : int + Batch size for Ghost Batch Normalization + momentum : float + Float value between 0 and 1 which will be used for momentum in batch norm + mask_type : str + Either "sparsemax" or "entmax" : this is the masking function to use + """ + super(AttentiveTransformer, self).__init__() + self.fc = Linear(input_dim, group_dim, bias=False) + initialize_non_glu(self.fc, input_dim, group_dim) + self.bn = GBN( + group_dim, virtual_batch_size=virtual_batch_size, momentum=momentum + ) + + if mask_type == "sparsemax": + # Sparsemax + self.selector = sparsemax.Sparsemax(dim=-1) + elif mask_type == "entmax": + # Entmax + self.selector = sparsemax.Entmax15(dim=-1) + else: + raise NotImplementedError( + "Please choose either sparsemax" + "or entmax as masktype" + ) + + def forward(self, priors, processed_feat): + x = self.fc(processed_feat) + x = self.bn(x) + x = torch.mul(x, priors) + x = self.selector(x) + return x + + +class FeatTransformer(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + shared_layers, + n_glu_independent, + virtual_batch_size=128, + 
momentum=0.02, + ): + super(FeatTransformer, self).__init__() + """ + Initialize a feature transformer. + + Parameters + ---------- + input_dim : int + Input size + output_dim : int + Output_size + shared_layers : torch.nn.ModuleList + The shared block that should be common to every step + n_glu_independent : int + Number of independent GLU layers + virtual_batch_size : int + Batch size for Ghost Batch Normalization within GLU block(s) + momentum : float + Float value between 0 and 1 which will be used for momentum in batch norm + """ + + params = { + "n_glu": n_glu_independent, + "virtual_batch_size": virtual_batch_size, + "momentum": momentum, + } + + if shared_layers is None: + # no shared layers + self.shared = torch.nn.Identity() + is_first = True + else: + self.shared = GLU_Block( + input_dim, + output_dim, + first=True, + shared_layers=shared_layers, + n_glu=len(shared_layers), + virtual_batch_size=virtual_batch_size, + momentum=momentum, + ) + is_first = False + + if n_glu_independent == 0: + # no independent layers + self.specifics = torch.nn.Identity() + else: + spec_input_dim = input_dim if is_first else output_dim + self.specifics = GLU_Block( + spec_input_dim, output_dim, first=is_first, **params + ) + + def forward(self, x): + x = self.shared(x) + x = self.specifics(x) + return x + + +class GLU_Block(torch.nn.Module): + """ + Independent GLU block, specific to each step + """ + + def __init__( + self, + input_dim, + output_dim, + n_glu=2, + first=False, + shared_layers=None, + virtual_batch_size=128, + momentum=0.02, + ): + super(GLU_Block, self).__init__() + self.first = first + self.shared_layers = shared_layers + self.n_glu = n_glu + self.glu_layers = torch.nn.ModuleList() + + params = {"virtual_batch_size": virtual_batch_size, "momentum": momentum} + + fc = shared_layers[0] if shared_layers else None + self.glu_layers.append(GLU_Layer(input_dim, output_dim, fc=fc, **params)) + for glu_id in range(1, self.n_glu): + fc = shared_layers[glu_id] if 
shared_layers else None + self.glu_layers.append(GLU_Layer(output_dim, output_dim, fc=fc, **params)) + + def forward(self, x): + scale = torch.sqrt(torch.FloatTensor([0.5]).to(x.device)) + if self.first: # the first layer of the block has no scale multiplication + x = self.glu_layers[0](x) + layers_left = range(1, self.n_glu) + else: + layers_left = range(self.n_glu) + + for glu_id in layers_left: + x = torch.add(x, self.glu_layers[glu_id](x)) + x = x * scale + return x + + +class GLU_Layer(torch.nn.Module): + def __init__( + self, input_dim, output_dim, fc=None, virtual_batch_size=128, momentum=0.02 + ): + super(GLU_Layer, self).__init__() + + self.output_dim = output_dim + if fc: + self.fc = fc + else: + self.fc = Linear(input_dim, 2 * output_dim, bias=False) + initialize_glu(self.fc, input_dim, 2 * output_dim) + + self.bn = GBN( + 2 * output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum + ) + + def forward(self, x): + x = self.fc(x) + x = self.bn(x) + out = torch.mul(x[:, : self.output_dim], torch.sigmoid(x[:, self.output_dim :])) + return out + + +class EmbeddingGenerator(torch.nn.Module): + """ + Classical embeddings generator + """ + + def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dims, group_matrix): + """This is an embedding module for an entire set of features + + Parameters + ---------- + input_dim : int + Number of features coming as input (number of columns) + cat_dims : list of int + Number of modalities for each categorial features + If the list is empty, no embeddings will be done + cat_idxs : list of int + Positional index for each categorical features in inputs + cat_emb_dim : list of int + Embedding dimension for each categorical features + If int, the same embedding dimension will be used for all categorical features + group_matrix : torch matrix + Original group matrix before embeddings + """ + super(EmbeddingGenerator, self).__init__() + + if cat_dims == [] and cat_idxs == []: + self.skip_embedding = True + 
self.post_embed_dim = input_dim + self.embedding_group_matrix = group_matrix.to(group_matrix.device) + return + else: + self.skip_embedding = False + + self.post_embed_dim = int(input_dim + np.sum(cat_emb_dims) - len(cat_emb_dims)) + + self.embeddings = torch.nn.ModuleList() + + for cat_dim, emb_dim in zip(cat_dims, cat_emb_dims): + self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) + + # record continuous indices + self.continuous_idx = torch.ones(input_dim, dtype=torch.bool) + self.continuous_idx[cat_idxs] = 0 + + # update group matrix + n_groups = group_matrix.shape[0] + self.embedding_group_matrix = torch.empty( + (n_groups, self.post_embed_dim), device=group_matrix.device + ) + for group_idx in range(n_groups): + post_emb_idx = 0 + cat_feat_counter = 0 + for init_feat_idx in range(input_dim): + if self.continuous_idx[init_feat_idx] == 1: + # this means that no embedding is applied to this column + self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[ + group_idx, init_feat_idx + ] # noqa + post_emb_idx += 1 + else: + # this is a categorical feature which creates multiple embeddings + n_embeddings = cat_emb_dims[cat_feat_counter] + self.embedding_group_matrix[ + group_idx, post_emb_idx : post_emb_idx + n_embeddings + ] = ( + group_matrix[group_idx, init_feat_idx] / n_embeddings + ) # noqa + post_emb_idx += n_embeddings + cat_feat_counter += 1 + + def forward(self, x): + """ + Apply embeddings to inputs + Inputs should be (batch_size, input_dim) + Outputs will be of size (batch_size, self.post_embed_dim) + """ + if self.skip_embedding: + # no embeddings required + return x + + cols = [] + cat_feat_counter = 0 + for feat_init_idx, is_continuous in enumerate(self.continuous_idx): + # Enumerate through continuous idx boolean mask to apply embeddings + if is_continuous: + cols.append(x[:, feat_init_idx].float().view(-1, 1)) + else: + cols.append( + self.embeddings[cat_feat_counter](x[:, feat_init_idx].long()) + ) + cat_feat_counter += 1 + # 
concat + post_embeddings = torch.cat(cols, dim=1) + return post_embeddings + + +class RandomObfuscator(torch.nn.Module): + """ + Creates and applies obfuscation masks. + The obfuscation is done at group level to match attention. + """ + + def __init__(self, pretraining_ratio, group_matrix): + """ + This creates random obfuscation for self-supervised pretraining + Parameters + ---------- + pretraining_ratio : float + Ratio of features to randomly discard for reconstruction + + """ + super(RandomObfuscator, self).__init__() + self.pretraining_ratio = pretraining_ratio + # group matrix is set to boolean here to pass all possible information + self.group_matrix = (group_matrix > 0) + 0.0 + self.num_groups = group_matrix.shape[0] + + def forward(self, x): + """ + Generate random obfuscation mask. + + Returns + ------- + masked input and obfuscated variables. + """ + bs = x.shape[0] + + obfuscated_groups = torch.bernoulli( + self.pretraining_ratio * torch.ones((bs, self.num_groups), device=x.device) + ) + obfuscated_vars = torch.matmul(obfuscated_groups, self.group_matrix) + masked_input = torch.mul(1 - obfuscated_vars, x) + return masked_input, obfuscated_groups, obfuscated_vars From d29ef37a94d6123a7ec0b3250799164b574a1fdc Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 2 Apr 2023 21:13:24 +0200 Subject: [PATCH 37/95] add factory.py, let DDPM use TabNet, refactor --- .gitignore | 3 +- src/synthcity/plugins/core/dataloader.py | 2 +- src/synthcity/plugins/core/models/convnet.py | 4 +- src/synthcity/plugins/core/models/factory.py | 137 ++ .../{data_encoder.py => feature_encoder.py} | 39 +- .../plugins/core/models/functions.py | 152 ++ src/synthcity/plugins/core/models/layers.py | 136 +- src/synthcity/plugins/core/models/mlp.py | 145 +- src/synthcity/plugins/core/models/tabnet.py | 1246 ++++++++--------- .../core/models/tabular_ddpm/__init__.py | 14 +- .../gaussian_multinomial_diffsuion.py | 36 +- .../core/models/tabular_ddpm/modules.py | 31 +- 
.../plugins/core/models/tabular_encoder.py | 10 +- src/synthcity/plugins/core/models/ts_model.py | 4 +- src/synthcity/plugins/generic/plugin_ddpm.py | 30 +- tests/plugins/core/models/test_mlp.py | 8 +- tests/plugins/generic/test_ddpm.py | 6 +- 17 files changed, 1077 insertions(+), 926 deletions(-) create mode 100644 src/synthcity/plugins/core/models/factory.py rename src/synthcity/plugins/core/models/{data_encoder.py => feature_encoder.py} (91%) create mode 100644 src/synthcity/plugins/core/models/functions.py diff --git a/.gitignore b/.gitignore index c41e0784..5195f6c2 100644 --- a/.gitignore +++ b/.gitignore @@ -67,5 +67,4 @@ lightning_logs generated MNIST cifar-10* -src/test.py -.tmp.py +local_test.py diff --git a/src/synthcity/plugins/core/dataloader.py b/src/synthcity/plugins/core/dataloader.py index 099e85c5..2a13424b 100644 --- a/src/synthcity/plugins/core/dataloader.py +++ b/src/synthcity/plugins/core/dataloader.py @@ -16,7 +16,7 @@ # synthcity absolute from synthcity.plugins.core.constraints import Constraints from synthcity.plugins.core.dataset import FlexibleDataset, TensorDataset -from synthcity.plugins.core.models.data_encoder import DatetimeEncoder +from synthcity.plugins.core.models.feature_encoder import DatetimeEncoder from synthcity.utils.compression import compress_dataset, decompress_dataset from synthcity.utils.serialization import dataframe_hash diff --git a/src/synthcity/plugins/core/models/convnet.py b/src/synthcity/plugins/core/models/convnet.py index e9a7719c..ae4260e6 100644 --- a/src/synthcity/plugins/core/models/convnet.py +++ b/src/synthcity/plugins/core/models/convnet.py @@ -69,8 +69,8 @@ class ConvNet(nn.Module): @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, - task_type: str, - model: nn.Module, # classification/regression + task_type: str, # classification/regression + model: nn.Module, lr: float = 1e-3, weight_decay: float = 1e-3, opt_betas: tuple = (0.9, 0.999), diff --git 
a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py new file mode 100644 index 00000000..e2d69525 --- /dev/null +++ b/src/synthcity/plugins/core/models/factory.py @@ -0,0 +1,137 @@ +# stdlib +from importlib import import_module +from typing import Any, Union + +# third party +from pydantic import validate_arguments +from torch import nn + +# synthcity relative +from .feature_encoder import ( + BayesianGMMEncoder, + DatetimeEncoder, + FeatureEncoder, + GaussianQuantileTransformer, + LabelEncoder, + MinMaxScaler, + OneHotEncoder, + RobustScaler, + StandardScaler, +) +from .layers import GumbelSoftmax + +# should only contain nn modules that can be used as building blocks in larger models +MODELS = dict( + mlp=".mlp.MLP", + rnn=nn.RNN, + gru=nn.GRU, + lstm=nn.LSTM, + transformer=".transformer.TransformerModel", + tabnet=".tabnet.TabNet", +) + +ACTIVATIONS = dict( + none=nn.Identity, + elu=nn.ELU, + relu=nn.ReLU, + leakyrelu=nn.LeakyReLU, + selu=nn.SELU, + tanh=nn.Tanh, + sigmoid=nn.Sigmoid, + softmax=nn.Softmax, + gumbelsoftmax=GumbelSoftmax, + gelu=nn.GELU, + silu=nn.SiLU, + swish=nn.SiLU, + hardtanh=nn.Hardtanh, + relu6=nn.ReLU6, + celu=nn.CELU, + glu=nn.GLU, + logsigmoid=nn.LogSigmoid, + softplus=nn.Softplus, +) + +FEATURE_ENCODERS = dict( + datetime=DatetimeEncoder, + onehot=OneHotEncoder, + label=LabelEncoder, + standard=StandardScaler, + minmax=MinMaxScaler, + robust=RobustScaler, + quantile=GaussianQuantileTransformer, + bayesiangmm=BayesianGMMEncoder, + none=FeatureEncoder, + passthrough=FeatureEncoder, +) + + +def _factory(type_: Union[str, type], params: dict, registry: dict) -> Any: + if isinstance(type_, type): + return type_(**params) + type_ = type_.lower().replace("_", "").replace("-", "") + if type_ in registry: + cls = registry[type_] + if isinstance(cls, str): + cls = registry[type_] = _dynamic_import(cls) + return cls(**params) + raise ValueError + + +def _dynamic_import(path: str) -> type: + """Avoid 
circular imports by importing dynamically.""" + if path.startswith("."): + package = __name__.rsplit(".", 1)[0] + else: + package = None + mod_path, cls = path.rsplit(".", 1) + module = import_module(mod_path, package) + return getattr(module, cls) + + +@validate_arguments(config=dict(arbitrary_types_allowed=True)) +def get_model(block: Union[str, type], params: dict) -> Any: + """Get a model from a name or a class. + + Named models: + - mlp + - rnn + - lstm + - transformer + - tabnet + """ + try: + return _factory(block, params, MODELS) + except ValueError: + raise ValueError(f"Unknown nn model: {block}") + + +@validate_arguments(config=dict(arbitrary_types_allowed=True)) +def get_nonlin(nonlin: Union[str, nn.Module], params: dict = {}) -> Any: + """Get a nonlinearity layer from a name or a class.""" + try: + return _factory(nonlin, params, ACTIVATIONS) + except ValueError: + raise ValueError(f"Unknown nonlinearity: {nonlin}") + + +@validate_arguments(config=dict(arbitrary_types_allowed=True)) +def get_feature_encoder(encoder: Union[str, type], params: dict = {}) -> Any: + """Get a feature encoder from a name or a class. 
+ + Named encoders: + - datetime + - onehot + - label + - standard + - minmax + - robust + - quantile + - bayesian_gmm + - passthrough + """ + if isinstance(encoder, type): # custom encoder + encoder = FeatureEncoder.wraps(encoder) + try: + return _factory(encoder, params, FEATURE_ENCODERS) + except ValueError: + raise ValueError(f"Unknown feature encoder: {encoder}") diff --git a/src/synthcity/plugins/core/models/data_encoder.py b/src/synthcity/plugins/core/models/feature_encoder.py similarity index 91% rename from src/synthcity/plugins/core/models/data_encoder.py rename to src/synthcity/plugins/core/models/feature_encoder.py index 518400fa..995a65a6 100644 --- a/src/synthcity/plugins/core/models/data_encoder.py +++ b/src/synthcity/plugins/core/models/feature_encoder.py @@ -119,7 +119,7 @@ def _inverse_transform(self, data: np.ndarray) -> np.ndarray: @classmethod def wraps( - cls, encoder_class: TransformerMixin, **params: Any + cls: type, encoder_class: TransformerMixin, **params: Any ) -> Type[FeatureEncoder]: """Wraps sklearn transformer to FeatureEncoder.""" @@ -260,7 +260,7 @@ def __init__( subsample: int = 10000, random_state: Any = None, copy: bool = True, - ): + ) -> None: super().__init__( n_quantiles=None, output_distribution="normal", @@ -273,38 +273,3 @@ def __init__( def fit(self, x: np.ndarray, y: Any = None) -> "GaussianQuantileTransformer": self.n_quantiles = max(min(len(x) // 30, 1000), 10) return super().fit(x, y) - - -ENCODERS = { - "datetime": DatetimeEncoder, - "onehot": OneHotEncoder, - "label": LabelEncoder, - "standard": StandardScaler, - "minmax": MinMaxScaler, - "robust": RobustScaler, - "quantile": GaussianQuantileTransformer, - "bayesian_gmm": BayesianGMMEncoder, - "passthrough": FeatureEncoder, -} - - -def get_encoder(encoder: Union[str, type]) -> Type[FeatureEncoder]: - """Get a registered encoder. 
- - Supported encoders: - - Datetime - - datetime - - Categorical - - onehot - - label - - Continuous - - standard - - minmax - - robust - - quantile - - bayesian_gmm - - Passthrough - """ - if isinstance(encoder, type): # custom encoder - return FeatureEncoder.wraps(encoder) - return ENCODERS[encoder] diff --git a/src/synthcity/plugins/core/models/functions.py b/src/synthcity/plugins/core/models/functions.py new file mode 100644 index 00000000..27801034 --- /dev/null +++ b/src/synthcity/plugins/core/models/functions.py @@ -0,0 +1,152 @@ +""" +Custom differentiable tensor functions. +""" +# stdlib +from typing import Any + +# third party +import torch +from torch.autograd import Function + + +# credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py +def _make_ix_like(input: torch.Tensor, dim: int = 0) -> torch.Tensor: + d = input.size(dim) + rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + view = [1] * input.dim() + view[0] = -1 + return rho.view(view).transpose(0, dim) + + +class SparsemaxFunction(Function): + """ + An implementation of sparsemax (Martins & Astudillo, 2016). See + :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
+ By Ben Peters and Vlad Niculae + """ + + @staticmethod + def forward( + ctx: Any, + input: torch.Tensor, + dim: int = -1, + ) -> torch.Tensor: + """sparsemax: normalizing sparse transform (a la softmax) + + Parameters + ---------- + ctx : torch.autograd.function._ContextMethodMixin + input : torch.Tensor + any shape + dim : int + dimension along which to apply sparsemax + + Returns + ------- + output : torch.Tensor + same shape as input + + """ + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val # same numerical stability trick as for softmax + tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) + output = torch.clamp(input - tau, min=0) + ctx.save_for_backward(supp_size, output) + return output + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + return grad_input, None + + @staticmethod + def _threshold_and_support( + input: torch.Tensor, dim: int = -1 + ) -> tuple[torch.Tensor, torch.Tensor]: + """Sparsemax building block: compute the threshold + + Parameters + ---------- + input: torch.Tensor + any dimension + dim : int + dimension along which to apply the sparsemax + + Returns + ------- + tau : torch.Tensor + the threshold value + support_size : torch.Tensor + + """ + + input_srt, _ = torch.sort(input, descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + return tau, support_size + + +class EntmaxFunction(Function): + """ + An 
implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins). See + :cite:`https://arxiv.org/abs/1905.05702 for detailed description. + Source: https://github.com/deep-spin/entmax + """ + + @staticmethod + def forward(ctx: Any, input: torch.Tensor, dim: int = -1) -> torch.Tensor: + ctx.dim = dim + + max_val, _ = input.max(dim=dim, keepdim=True) + input = input - max_val # same numerical stability trick as for softmax + input = input / 2 # divide by 2 to solve actual Entmax + + tau_star, _ = EntmaxFunction._threshold_and_support(input, dim) + output = torch.clamp(input - tau_star, min=0) ** 2 + ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: + (Y,) = ctx.saved_tensors + gppr = Y.sqrt() # = 1 / g'' (Y) + dX = grad_output * gppr + q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) + q = q.unsqueeze(ctx.dim) + dX -= q * gppr + return dX, None + + @staticmethod + def _threshold_and_support( + input: torch.Tensor, dim: int = -1 + ) -> tuple[torch.Tensor, torch.Tensor]: + Xsrt, _ = torch.sort(input, descending=True, dim=dim) + + rho = _make_ix_like(input, dim) + mean = Xsrt.cumsum(dim) / rho + mean_sq = (Xsrt**2).cumsum(dim) / rho + ss = rho * (mean_sq - mean**2) + delta = (1 - ss) / rho + + # NOTE this is not exactly the same as in reference algo + # Fortunately it seems the clamped values never wrongly + # get selected by tau <= sorted_z. Prove this! 
+ delta_nz = torch.clamp(delta, 0) + tau = mean - torch.sqrt(delta_nz) + + support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) + tau_star = tau.gather(dim, support_size - 1) + return tau_star, support_size diff --git a/src/synthcity/plugins/core/models/layers.py b/src/synthcity/plugins/core/models/layers.py index fb4a8ea6..9be97abd 100644 --- a/src/synthcity/plugins/core/models/layers.py +++ b/src/synthcity/plugins/core/models/layers.py @@ -1,10 +1,18 @@ # stdlib -from typing import Any, Optional +from typing import Any, List, Optional, Tuple, Type # third party +import numpy as np import torch +from pydantic import validate_arguments from torch import nn +# synthcity absolute +from synthcity.utils.constants import DEVICE + +# synthcity relative +from .functions import EntmaxFunction, SparsemaxFunction + class Permute(nn.Module): def __init__(self, *dims: Any) -> None: @@ -34,3 +42,129 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.transpose(*self.dims).contiguous() else: return x.transpose(*self.dims) + + +def SkipConnection(cls: Type[nn.Module]) -> Type[nn.Module]: + """Wraps a model to add a skip connection from the input to the output. 
+ + Example: + >>> ResidualBlock = SkipConnection(MLP) + >>> ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64) + SkipConnection(MLP)( + (model): Sequential( + (0): LinearLayer( + (model): Sequential( + (0): Linear(in_features=10, out_features=64, bias=True) + (1): ReLU() + ) + ) + (1): Linear(in_features=64, out_features=3, bias=True) + ) + (loss): MSELoss() + ) + """ + + class WrappedModule(cls): # type: ignore + device: torch.device = DEVICE + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def forward(self, X: torch.Tensor) -> torch.Tensor: + # if X.shape[-1] == 0: + # return torch.zeros((*X.shape[:-1], self.n_units_out)).to(self.device) + X = X.float().to(self.device) + out = super().forward(X) + return torch.cat([out, X], dim=-1) + + WrappedModule.__name__ = f"SkipConnection({cls.__name__})" + WrappedModule.__qualname__ = f"SkipConnection({cls.__qualname__})" + WrappedModule.__doc__ = f"""(With skipped connection) {cls.__doc__}""" + return WrappedModule + + +# class GLU(nn.Module): +# """Gated Linear Unit (GLU).""" + +# def __init__(self, activation: Union[str, nn.Module] = "sigmoid") -> None: +# super().__init__() +# if type(activation) == str: +# self.non_lin = get_nonlin(activation) +# else: +# self.non_lin = activation + +# def forward(self, x: Tensor) -> Tensor: +# if x.shape[-1] % 2: +# raise ValueError("The last dimension of the input tensor must be even.") +# a, b = x.chunk(2, dim=-1) +# return a * self.non_lin(b) + + +class GumbelSoftmax(nn.Module): + def __init__( + self, tau: float = 0.2, hard: bool = False, eps: float = 1e-10, dim: int = -1 + ) -> None: + super(GumbelSoftmax, self).__init__() + + self.tau = tau + self.hard = hard + self.eps = eps + self.dim = dim + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + return nn.functional.gumbel_softmax( + logits, tau=self.tau, hard=self.hard, eps=self.eps, dim=self.dim + ) + + +class MultiActivationHead(nn.Module): + """Final layer with multiple activations. 
Useful for tabular data.""" + + def __init__( + self, + activations: List[Tuple[nn.Module, int]], + device: Any = DEVICE, + ) -> None: + super(MultiActivationHead, self).__init__() + self.activations = [] + self.activation_lengths = [] + self.device = device + + for activation, length in activations: + self.activations.append(activation) + self.activation_lengths.append(length) + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def forward(self, X: torch.Tensor) -> torch.Tensor: + if X.shape[-1] != np.sum(self.activation_lengths): + raise RuntimeError( + f"Shape mismatch for the activations: expected {np.sum(self.activation_lengths)}. Got shape {X.shape}." + ) + + split = 0 + out = torch.zeros(X.shape).to(self.device) + + for activation, step in zip(self.activations, self.activation_lengths): + out[..., split : split + step] = activation(X[..., split : split + step]) + + split += step + + return out + + +class Sparsemax(nn.Module): + def __init__(self, dim: int = -1) -> None: + super(Sparsemax, self).__init__() + self.dim = dim + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def forward(self, input: torch.Tensor) -> torch.Tensor: + return SparsemaxFunction.apply(input, self.dim) + + +class Entmax(nn.Module): + def __init__(self, dim: int = -1) -> None: + super(Entmax, self).__init__() + self.dim = dim + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def forward(self, input: torch.Tensor) -> torch.Tensor: + return EntmaxFunction.apply(input, self.dim) diff --git a/src/synthcity/plugins/core/models/mlp.py b/src/synthcity/plugins/core/models/mlp.py index 5ab63464..5d85c1c8 100644 --- a/src/synthcity/plugins/core/models/mlp.py +++ b/src/synthcity/plugins/core/models/mlp.py @@ -1,86 +1,25 @@ # stdlib -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple # third party import numpy as np import torch from pydantic import validate_arguments -from 
torch import Tensor, nn +from torch import nn from torch.utils.data import DataLoader, TensorDataset # synthcity absolute import synthcity.logger as log +from synthcity.plugins.core.models.factory import get_nonlin +from synthcity.plugins.core.models.layers import ( + GumbelSoftmax, + MultiActivationHead, + SkipConnection, +) from synthcity.utils.constants import DEVICE from synthcity.utils.reproducibility import enable_reproducible_results -class GumbelSoftmax(nn.Module): - def __init__( - self, tau: float = 0.2, hard: bool = False, eps: float = 1e-10, dim: int = -1 - ) -> None: - super(GumbelSoftmax, self).__init__() - - self.tau = tau - self.hard = hard - self.eps = eps - self.dim = dim - - def forward(self, logits: torch.Tensor) -> torch.Tensor: - return nn.functional.gumbel_softmax( - logits, tau=self.tau, hard=self.hard, eps=self.eps, dim=self.dim - ) - - -class GLU(nn.Module): - """Gated Linear Unit (GLU).""" - - def __init__(self, activation: Union[str, nn.Module] = "sigmoid") -> None: - super().__init__() - if type(activation) == str: - self.non_lin = get_nonlin(activation) - else: - self.non_lin = activation - - def forward(self, x: Tensor) -> Tensor: - if x.shape[-1] % 2: - raise ValueError("The last dimension of the input tensor must be even.") - a, b = x.chunk(2, dim=-1) - return a * self.non_lin(b) - - -def get_nonlin(name: Union[str, nn.Module]) -> nn.Module: - if isinstance(name, nn.Module): - return name - elif name == "none": - return nn.Identity() - elif name == "elu": - return nn.ELU() - elif name == "relu": - return nn.ReLU() - elif name == "leaky_relu": - return nn.LeakyReLU() - elif name == "selu": - return nn.SELU() - elif name == "tanh": - return nn.Tanh() - elif name == "sigmoid": - return nn.Sigmoid() - elif name == "softmax": - return GumbelSoftmax() - elif name == "gelu": - return nn.GELU() - elif name == "glu": - return GLU() - elif name == "reglu": - return GLU("relu") - elif name == "geglu": - return GLU("gelu") - elif name in 
("silu", "swish"): - return nn.SiLU() - else: - raise ValueError(f"Unknown nonlinearity {name}") - - class LinearLayer(nn.Module): @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( @@ -114,70 +53,7 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: return self.model(X.float()).to(self.device) -class ResidualLayer(LinearLayer): - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def __init__( - self, - n_units_in: int, - n_units_out: int, - dropout: float = 0, - batch_norm: bool = False, - nonlin: Optional[str] = "relu", - device: Any = DEVICE, - ) -> None: - super(ResidualLayer, self).__init__( - n_units_in, - n_units_out, - dropout=dropout, - batch_norm=batch_norm, - nonlin=nonlin, - device=device, - ) - self.device = device - self.n_units_out = n_units_out - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def forward(self, X: torch.Tensor) -> torch.Tensor: - if X.shape[-1] == 0: - return torch.zeros((*X.shape[:-1], self.n_units_out)).to(self.device) - - out = self.model(X.float()) - return torch.cat([out, X], dim=-1).to(self.device) - - -class MultiActivationHead(nn.Module): - """Final layer with multiple activations. Useful for tabular data.""" - - def __init__( - self, - activations: List[Tuple[nn.Module, int]], - device: Any = DEVICE, - ) -> None: - super(MultiActivationHead, self).__init__() - self.activations = [] - self.activation_lengths = [] - self.device = device - - for activation, length in activations: - self.activations.append(activation) - self.activation_lengths.append(length) - - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def forward(self, X: torch.Tensor) -> torch.Tensor: - if X.shape[-1] != np.sum(self.activation_lengths): - raise RuntimeError( - f"Shape mismatch for the activations: expected {np.sum(self.activation_lengths)}. Got shape {X.shape}." 
- ) - - split = 0 - out = torch.zeros(X.shape).to(self.device) - - for activation, step in zip(self.activations, self.activation_lengths): - out[..., split : split + step] = activation(X[..., split : split + step]) - - split += step - - return out +ResidualLayer = SkipConnection(LinearLayer) class MLP(nn.Module): @@ -235,9 +111,10 @@ class MLP(nn.Module): @validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, - task_type: str, # classification/regression + *, n_units_in: int, n_units_out: int, + task_type: str = "regression", # classification/regression n_layers_hidden: int = 1, n_units_hidden: int = 100, nonlin: str = "relu", diff --git a/src/synthcity/plugins/core/models/tabnet.py b/src/synthcity/plugins/core/models/tabnet.py index 25383cb9..5a4f3051 100644 --- a/src/synthcity/plugins/core/models/tabnet.py +++ b/src/synthcity/plugins/core/models/tabnet.py @@ -1,105 +1,215 @@ +# stdlib +from typing import List, Optional, Tuple + # third party import numpy as np import torch -from torch.autograd import Function from torch.nn import BatchNorm1d, Linear, ReLU +# synthcity relative +from .layers import Entmax, Sparsemax + +# class TabNet(torch.nn.Module): +# def __init__( +# self, +# input_dim, +# output_dim, +# n_d=8, +# n_a=8, +# n_steps=3, +# gamma=1.3, +# n_independent=2, +# n_shared=2, +# epsilon=1e-15, +# virtual_batch_size=128, +# momentum=0.02, +# mask_type="sparsemax", +# group_attention_matrix=None, +# ): +# """ +# Defines TabNet network + +# Parameters +# ---------- +# input_dim : int +# Initial number of features +# output_dim : int +# Dimension of network output +# examples : one for regression, 2 for binary classification etc... 
+# n_d : int +# Dimension of the prediction layer (usually between 4 and 64) +# n_a : int +# Dimension of the attention layer (usually between 4 and 64) +# n_steps : int +# Number of successive steps in the network (usually between 3 and 10) +# gamma : float +# Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) +# n_independent : int +# Number of independent GLU layer in each GLU block (default 2) +# n_shared : int +# Number of independent GLU layer in each GLU block (default 2) +# epsilon : float +# Avoid log(0), this should be kept very low +# virtual_batch_size : int +# Batch size for Ghost Batch Normalization +# momentum : float +# Float value between 0 and 1 which will be used for momentum in all batch norm +# mask_type : str +# Either "sparsemax" or "entmax" : this is the masking function to use +# group_attention_matrix : torch matrix +# Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j +# """ +# super(TabNet, self).__init__() + +# if group_attention_matrix is None: +# group_attention_matrix = torch.Tensor([]) + +# self.input_dim = input_dim +# self.output_dim = output_dim +# self.n_d = n_d +# self.n_a = n_a +# self.n_steps = n_steps +# self.gamma = gamma +# self.epsilon = epsilon +# self.n_independent = n_independent +# self.n_shared = n_shared +# self.mask_type = mask_type + +# self.virtual_batch_size = virtual_batch_size +# self.post_embed_dim = self.embedder.post_embed_dim + +# self.tabnet = TabNetNoEmbeddings( +# self.post_embed_dim, +# output_dim, +# n_d, +# n_a, +# n_steps, +# gamma, +# n_independent, +# n_shared, +# epsilon, +# virtual_batch_size, +# momentum, +# mask_type, +# self.embedder.embedding_group_matrix, +# ) + +# def forward(self, x): +# x = self.embedder(x) +# return self.tabnet(x) + +# def forward_masks(self, x): +# x = self.embedder(x) +# return self.tabnet.forward_masks(x) -# credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py -def 
_make_ix_like(input, dim=0): - d = input.size(dim) - rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) - view = [1] * input.dim() - view[0] = -1 - return rho.view(view).transpose(0, dim) - - -class SparsemaxFunction(Function): - """ - An implementation of sparsemax (Martins & Astudillo, 2016). See - :cite:`DBLP:journals/corr/MartinsA16` for detailed description. - By Ben Peters and Vlad Niculae - """ - - @staticmethod - def forward(ctx, input, dim=-1): - """sparsemax: normalizing sparse transform (a la softmax) - - Parameters - ---------- - ctx : torch.autograd.function._ContextMethodMixin - input : torch.Tensor - any shape - dim : int - dimension along which to apply sparsemax - - Returns - ------- - output : torch.Tensor - same shape as input +class TabNet(torch.nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + n_d: int = 8, + n_a: int = 8, + n_steps: int = 3, + gamma: float = 1.3, + n_independent: int = 2, + n_shared: int = 2, + epsilon: float = 1e-15, + virtual_batch_size: int = 128, + momentum: float = 0.02, + mask_type: str = "sparsemax", + group_attention_matrix: Optional[torch.Tensor] = None, + ) -> None: """ - ctx.dim = dim - max_val, _ = input.max(dim=dim, keepdim=True) - input -= max_val # same numerical stability trick as for softmax - tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) - output = torch.clamp(input - tau, min=0) - ctx.save_for_backward(supp_size, output) - return output - - @staticmethod - def backward(ctx, grad_output): - supp_size, output = ctx.saved_tensors - dim = ctx.dim - grad_input = grad_output.clone() - grad_input[output == 0] = 0 - - v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() - v_hat = v_hat.unsqueeze(dim) - grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) - return grad_input, None - - @staticmethod - def _threshold_and_support(input, dim=-1): - """Sparsemax building block: compute the threshold + Defines main 
part of the TabNet network without the embedding layers. Parameters ---------- - input: torch.Tensor - any dimension - dim : int - dimension along which to apply the sparsemax + input_dim : int + Number of features + output_dim : int or list of int for multi task classification + Dimension of network output + examples : one for regression, 2 for binary classification etc... + n_d : int + Dimension of the prediction layer (usually between 4 and 64) + n_a : int + Dimension of the attention layer (usually between 4 and 64) + n_steps : int + Number of successive steps in the network (usually between 3 and 10) + gamma : float + Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) + n_independent : int + Number of independent GLU layer in each GLU block (default 2) + n_shared : int + Number of independent GLU layer in each GLU block (default 2) + epsilon : float + Avoid log(0), this should be kept very low + virtual_batch_size : int + Batch size for Ghost Batch Normalization + momentum : float + Float value between 0 and 1 which will be used for momentum in all batch norm + mask_type : str + Either "sparsemax" or "entmax" : this is the masking function to use + group_attention_matrix : torch matrix + Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j + """ - Returns - ------- - tau : torch.Tensor - the threshold value - support_size : torch.Tensor + if n_steps <= 0: + raise ValueError("n_steps should be a positive integer.") + if n_independent == 0 and n_shared == 0: + raise ValueError("n_shared and n_independent can't be both zero.") - """ + super(TabNet, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.epsilon = epsilon + self.n_independent = n_independent + self.n_shared = n_shared + self.virtual_batch_size = virtual_batch_size + self.mask_type = mask_type + self.initial_bn = 
BatchNorm1d(self.input_dim, momentum=0.01) - input_srt, _ = torch.sort(input, descending=True, dim=dim) - input_cumsum = input_srt.cumsum(dim) - 1 - rhos = _make_ix_like(input, dim) - support = rhos * input_srt > input_cumsum + self.encoder = TabNetEncoder( + input_dim=input_dim, + output_dim=output_dim, + n_d=n_d, + n_a=n_a, + n_steps=n_steps, + gamma=gamma, + n_independent=n_independent, + n_shared=n_shared, + epsilon=epsilon, + virtual_batch_size=virtual_batch_size, + momentum=momentum, + mask_type=mask_type, + group_attention_matrix=group_attention_matrix, + ) - support_size = support.sum(dim=dim).unsqueeze(dim) - tau = input_cumsum.gather(dim, support_size - 1) - tau /= support_size.to(input.dtype) - return tau, support_size + self.final_mapping = Linear(n_d, output_dim, bias=False) + initialize_non_glu(self.final_mapping, n_d, output_dim) + def forward(self, x: torch.Tensor) -> torch.Tensor: + steps_output, M_loss = self.encoder(x) + self.M_loss = M_loss + res = torch.sum(torch.stack(steps_output, dim=0), dim=0) + return self.final_mapping(res) -sparsemax = SparsemaxFunction.apply + def forward_masks(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder.forward_masks(x) -def initialize_non_glu(module, input_dim, output_dim): +def initialize_non_glu(module: Linear, input_dim: int, output_dim: int) -> None: gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(4 * input_dim)) torch.nn.init.xavier_normal_(module.weight, gain=gain_value) # torch.nn.init.zeros_(module.bias) return -def initialize_glu(module, input_dim, output_dim): +def initialize_glu(module: Linear, input_dim: int, output_dim: int) -> None: gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) torch.nn.init.xavier_normal_(module.weight, gain=gain_value) # torch.nn.init.zeros_(module.bias) @@ -112,14 +222,16 @@ class GBN(torch.nn.Module): https://arxiv.org/abs/1705.08741 """ - def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01): + def __init__( + self, 
input_dim: int, virtual_batch_size: int = 128, momentum: float = 0.01 + ) -> None: super(GBN, self).__init__() self.input_dim = input_dim self.virtual_batch_size = virtual_batch_size self.bn = BatchNorm1d(self.input_dim, momentum=momentum) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) res = [self.bn(x_) for x_ in chunks] @@ -129,20 +241,20 @@ def forward(self, x): class TabNetEncoder(torch.nn.Module): def __init__( self, - input_dim, - output_dim, - n_d=8, - n_a=8, - n_steps=3, - gamma=1.3, - n_independent=2, - n_shared=2, - epsilon=1e-15, - virtual_batch_size=128, - momentum=0.02, - mask_type="sparsemax", - group_attention_matrix=None, - ): + input_dim: int, + output_dim: int, + n_d: int = 8, + n_a: int = 8, + n_steps: int = 3, + gamma: float = 1.3, + n_independent: int = 2, + n_shared: int = 2, + epsilon: float = 1e-15, + virtual_batch_size: int = 128, + momentum: float = 0.02, + mask_type: str = "sparsemax", + group_attention_matrix: Optional[torch.Tensor] = None, + ) -> None: """ Defines main part of the TabNet network without the embedding layers. 
@@ -173,8 +285,6 @@ def __init__( Float value between 0 and 1 which will be used for momentum in all batch norm mask_type : str Either "sparsemax" or "entmax" : this is the masking function to use - group_attention_matrix : torch matrix - Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j """ super(TabNetEncoder, self).__init__() self.input_dim = input_dim @@ -210,7 +320,6 @@ def __init__( shared_feat_transform.append( Linear(n_d + n_a, 2 * (n_d + n_a), bias=False) ) - else: shared_feat_transform = None @@ -238,7 +347,6 @@ def __init__( attention = AttentiveTransformer( n_a, self.attention_dim, - group_matrix=group_attention_matrix, virtual_batch_size=self.virtual_batch_size, momentum=momentum, mask_type=self.mask_type, @@ -246,14 +354,16 @@ def __init__( self.feat_transformers.append(transformer) self.att_transformers.append(attention) - def forward(self, x, prior=None): + def forward( + self, x: torch.Tensor, prior: Optional[torch.Tensor] = None + ) -> Tuple[List[torch.Tensor], torch.Tensor]: x = self.initial_bn(x) bs = x.shape[0] # batch size if prior is None: prior = torch.ones((bs, self.attention_dim)).to(x.device) - M_loss = 0 + M_loss = 0.0 att = self.initial_splitter(x)[:, self.n_d :] steps_output = [] for step in range(self.n_steps): @@ -275,18 +385,18 @@ def forward(self, x, prior=None): M_loss /= self.n_steps return steps_output, M_loss - def forward_masks(self, x): + def forward_masks(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: x = self.initial_bn(x) bs = x.shape[0] # batch size prior = torch.ones((bs, self.attention_dim)).to(x.device) M_explain = torch.zeros(x.shape).to(x.device) att = self.initial_splitter(x)[:, self.n_d :] - masks = {} + masks = [] for step in range(self.n_steps): M = self.att_transformers[step](prior, att) M_feature_level = torch.matmul(M, self.group_attention_matrix) - masks[step] = M_feature_level + masks.append(M_feature_level) # update prior prior = torch.mul(self.gamma - 
M, prior) # output @@ -302,423 +412,200 @@ def forward_masks(self, x): return M_explain, masks -class TabNetDecoder(torch.nn.Module): - def __init__( - self, - input_dim, - n_d=8, - n_steps=3, - n_independent=1, - n_shared=1, - virtual_batch_size=128, - momentum=0.02, - ): - """ - Defines main part of the TabNet network without the embedding layers. - - Parameters - ---------- - input_dim : int - Number of features - output_dim : int or list of int for multi task classification - Dimension of network output - examples : one for regression, 2 for binary classification etc... - n_d : int - Dimension of the prediction layer (usually between 4 and 64) - n_steps : int - Number of successive steps in the network (usually between 3 and 10) - gamma : float - Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) - n_independent : int - Number of independent GLU layer in each GLU block (default 1) - n_shared : int - Number of independent GLU layer in each GLU block (default 1) - virtual_batch_size : int - Batch size for Ghost Batch Normalization - momentum : float - Float value between 0 and 1 which will be used for momentum in all batch norm - """ - super(TabNetDecoder, self).__init__() - self.input_dim = input_dim - self.n_d = n_d - self.n_steps = n_steps - self.n_independent = n_independent - self.n_shared = n_shared - self.virtual_batch_size = virtual_batch_size - - self.feat_transformers = torch.nn.ModuleList() - - if self.n_shared > 0: - shared_feat_transform = torch.nn.ModuleList() - for i in range(self.n_shared): - if i == 0: - shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) - else: - shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) - - else: - shared_feat_transform = None - - for step in range(n_steps): - transformer = FeatTransformer( - n_d, - n_d, - shared_feat_transform, - n_glu_independent=self.n_independent, - virtual_batch_size=self.virtual_batch_size, - momentum=momentum, - ) - 
self.feat_transformers.append(transformer) - - self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False) - initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim) - - def forward(self, steps_output): - res = 0 - for step_nb, step_output in enumerate(steps_output): - x = self.feat_transformers[step_nb](step_output) - res = torch.add(res, x) - res = self.reconstruction_layer(res) - return res - - -class TabNetPretraining(torch.nn.Module): - def __init__( - self, - input_dim, - pretraining_ratio=0.2, - n_d=8, - n_a=8, - n_steps=3, - gamma=1.3, - cat_idxs=[], - cat_dims=[], - cat_emb_dim=1, - n_independent=2, - n_shared=2, - epsilon=1e-15, - virtual_batch_size=128, - momentum=0.02, - mask_type="sparsemax", - n_shared_decoder=1, - n_indep_decoder=1, - group_attention_matrix=None, - ): - super(TabNetPretraining, self).__init__() - - self.cat_idxs = cat_idxs or [] - self.cat_dims = cat_dims or [] - self.cat_emb_dim = cat_emb_dim - - self.input_dim = input_dim - self.n_d = n_d - self.n_a = n_a - self.n_steps = n_steps - self.gamma = gamma - self.epsilon = epsilon - self.n_independent = n_independent - self.n_shared = n_shared - self.mask_type = mask_type - self.pretraining_ratio = pretraining_ratio - self.n_shared_decoder = n_shared_decoder - self.n_indep_decoder = n_indep_decoder - - if self.n_steps <= 0: - raise ValueError("n_steps should be a positive integer.") - if self.n_independent == 0 and self.n_shared == 0: - raise ValueError("n_shared and n_independent can't be both zero.") - - self.virtual_batch_size = virtual_batch_size - self.embedder = EmbeddingGenerator( - input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix - ) - self.post_embed_dim = self.embedder.post_embed_dim - - self.masker = RandomObfuscator( - self.pretraining_ratio, group_matrix=self.embedder.embedding_group_matrix - ) - self.encoder = TabNetEncoder( - input_dim=self.post_embed_dim, - output_dim=self.post_embed_dim, - n_d=n_d, - n_a=n_a, - n_steps=n_steps, - 
gamma=gamma, - n_independent=n_independent, - n_shared=n_shared, - epsilon=epsilon, - virtual_batch_size=virtual_batch_size, - momentum=momentum, - mask_type=mask_type, - group_attention_matrix=self.embedder.embedding_group_matrix, - ) - self.decoder = TabNetDecoder( - self.post_embed_dim, - n_d=n_d, - n_steps=n_steps, - n_independent=self.n_indep_decoder, - n_shared=self.n_shared_decoder, - virtual_batch_size=virtual_batch_size, - momentum=momentum, - ) - - def forward(self, x): - """ - Returns: res, embedded_x, obf_vars - res : output of reconstruction - embedded_x : embedded input - obf_vars : which variable where obfuscated - """ - embedded_x = self.embedder(x) - if self.training: - masked_x, obfuscated_groups, obfuscated_vars = self.masker(embedded_x) - # set prior of encoder with obfuscated groups - prior = 1 - obfuscated_groups - steps_out, _ = self.encoder(masked_x, prior=prior) - res = self.decoder(steps_out) - return res, embedded_x, obfuscated_vars - else: - steps_out, _ = self.encoder(embedded_x) - res = self.decoder(steps_out) - return res, embedded_x, torch.ones(embedded_x.shape).to(x.device) - - def forward_masks(self, x): - embedded_x = self.embedder(x) - return self.encoder.forward_masks(embedded_x) - - -class TabNetNoEmbeddings(torch.nn.Module): - def __init__( - self, - input_dim, - output_dim, - n_d=8, - n_a=8, - n_steps=3, - gamma=1.3, - n_independent=2, - n_shared=2, - epsilon=1e-15, - virtual_batch_size=128, - momentum=0.02, - mask_type="sparsemax", - group_attention_matrix=None, - ): - """ - Defines main part of the TabNet network without the embedding layers. - - Parameters - ---------- - input_dim : int - Number of features - output_dim : int or list of int for multi task classification - Dimension of network output - examples : one for regression, 2 for binary classification etc... 
- n_d : int - Dimension of the prediction layer (usually between 4 and 64) - n_a : int - Dimension of the attention layer (usually between 4 and 64) - n_steps : int - Number of successive steps in the network (usually between 3 and 10) - gamma : float - Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) - n_independent : int - Number of independent GLU layer in each GLU block (default 2) - n_shared : int - Number of independent GLU layer in each GLU block (default 2) - epsilon : float - Avoid log(0), this should be kept very low - virtual_batch_size : int - Batch size for Ghost Batch Normalization - momentum : float - Float value between 0 and 1 which will be used for momentum in all batch norm - mask_type : str - Either "sparsemax" or "entmax" : this is the masking function to use - group_attention_matrix : torch matrix - Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j - """ - super(TabNetNoEmbeddings, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.is_multi_task = isinstance(output_dim, list) - self.n_d = n_d - self.n_a = n_a - self.n_steps = n_steps - self.gamma = gamma - self.epsilon = epsilon - self.n_independent = n_independent - self.n_shared = n_shared - self.virtual_batch_size = virtual_batch_size - self.mask_type = mask_type - self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01) - - self.encoder = TabNetEncoder( - input_dim=input_dim, - output_dim=output_dim, - n_d=n_d, - n_a=n_a, - n_steps=n_steps, - gamma=gamma, - n_independent=n_independent, - n_shared=n_shared, - epsilon=epsilon, - virtual_batch_size=virtual_batch_size, - momentum=momentum, - mask_type=mask_type, - group_attention_matrix=group_attention_matrix, - ) - - if self.is_multi_task: - self.multi_task_mappings = torch.nn.ModuleList() - for task_dim in output_dim: - task_mapping = Linear(n_d, task_dim, bias=False) - initialize_non_glu(task_mapping, n_d, task_dim) - 
self.multi_task_mappings.append(task_mapping) - else: - self.final_mapping = Linear(n_d, output_dim, bias=False) - initialize_non_glu(self.final_mapping, n_d, output_dim) - - def forward(self, x): - res = 0 - steps_output, M_loss = self.encoder(x) - res = torch.sum(torch.stack(steps_output, dim=0), dim=0) - - if self.is_multi_task: - # Result will be in list format - out = [] - for task_mapping in self.multi_task_mappings: - out.append(task_mapping(res)) - else: - out = self.final_mapping(res) - return out, M_loss - - def forward_masks(self, x): - return self.encoder.forward_masks(x) - - -class TabNet(torch.nn.Module): - def __init__( - self, - input_dim, - output_dim, - n_d=8, - n_a=8, - n_steps=3, - gamma=1.3, - cat_idxs=[], - cat_dims=[], - cat_emb_dim=1, - n_independent=2, - n_shared=2, - epsilon=1e-15, - virtual_batch_size=128, - momentum=0.02, - mask_type="sparsemax", - group_attention_matrix=[], - ): - """ - Defines TabNet network - - Parameters - ---------- - input_dim : int - Initial number of features - output_dim : int - Dimension of network output - examples : one for regression, 2 for binary classification etc... 
- n_d : int - Dimension of the prediction layer (usually between 4 and 64) - n_a : int - Dimension of the attention layer (usually between 4 and 64) - n_steps : int - Number of successive steps in the network (usually between 3 and 10) - gamma : float - Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) - cat_idxs : list of int - Index of each categorical column in the dataset - cat_dims : list of int - Number of categories in each categorical column - cat_emb_dim : int or list of int - Size of the embedding of categorical features - if int, all categorical features will have same embedding size - if list of int, every corresponding feature will have specific size - n_independent : int - Number of independent GLU layer in each GLU block (default 2) - n_shared : int - Number of independent GLU layer in each GLU block (default 2) - epsilon : float - Avoid log(0), this should be kept very low - virtual_batch_size : int - Batch size for Ghost Batch Normalization - momentum : float - Float value between 0 and 1 which will be used for momentum in all batch norm - mask_type : str - Either "sparsemax" or "entmax" : this is the masking function to use - group_attention_matrix : torch matrix - Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j - """ - super(TabNet, self).__init__() - self.cat_idxs = cat_idxs or [] - self.cat_dims = cat_dims or [] - self.cat_emb_dim = cat_emb_dim - - self.input_dim = input_dim - self.output_dim = output_dim - self.n_d = n_d - self.n_a = n_a - self.n_steps = n_steps - self.gamma = gamma - self.epsilon = epsilon - self.n_independent = n_independent - self.n_shared = n_shared - self.mask_type = mask_type - - if self.n_steps <= 0: - raise ValueError("n_steps should be a positive integer.") - if self.n_independent == 0 and self.n_shared == 0: - raise ValueError("n_shared and n_independent can't be both zero.") - - self.virtual_batch_size = virtual_batch_size - self.embedder = 
EmbeddingGenerator( - input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix - ) - self.post_embed_dim = self.embedder.post_embed_dim - - self.tabnet = TabNetNoEmbeddings( - self.post_embed_dim, - output_dim, - n_d, - n_a, - n_steps, - gamma, - n_independent, - n_shared, - epsilon, - virtual_batch_size, - momentum, - mask_type, - self.embedder.embedding_group_matrix, - ) - - def forward(self, x): - x = self.embedder(x) - return self.tabnet(x) - - def forward_masks(self, x): - x = self.embedder(x) - return self.tabnet.forward_masks(x) +# class TabNetDecoder(torch.nn.Module): +# def __init__( +# self, +# input_dim, +# n_d=8, +# n_steps=3, +# n_independent=1, +# n_shared=1, +# virtual_batch_size=128, +# momentum=0.02, +# ): +# """ +# Defines main part of the TabNet network without the embedding layers. + +# Parameters +# ---------- +# input_dim : int +# Number of features +# output_dim : int or list of int for multi task classification +# Dimension of network output +# examples : one for regression, 2 for binary classification etc... 
+# n_d : int +# Dimension of the prediction layer (usually between 4 and 64) +# n_steps : int +# Number of successive steps in the network (usually between 3 and 10) +# gamma : float +# Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) +# n_independent : int +# Number of independent GLU layer in each GLU block (default 1) +# n_shared : int +# Number of independent GLU layer in each GLU block (default 1) +# virtual_batch_size : int +# Batch size for Ghost Batch Normalization +# momentum : float +# Float value between 0 and 1 which will be used for momentum in all batch norm +# """ +# super(TabNetDecoder, self).__init__() +# self.input_dim = input_dim +# self.n_d = n_d +# self.n_steps = n_steps +# self.n_independent = n_independent +# self.n_shared = n_shared +# self.virtual_batch_size = virtual_batch_size + +# self.feat_transformers = torch.nn.ModuleList() + +# if self.n_shared > 0: +# shared_feat_transform = torch.nn.ModuleList() +# for i in range(self.n_shared): +# if i == 0: +# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) +# else: +# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) + +# else: +# shared_feat_transform = None + +# for step in range(n_steps): +# transformer = FeatTransformer( +# n_d, +# n_d, +# shared_feat_transform, +# n_glu_independent=self.n_independent, +# virtual_batch_size=self.virtual_batch_size, +# momentum=momentum, +# ) +# self.feat_transformers.append(transformer) + +# self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False) +# initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim) + +# def forward(self, steps_output): +# res = 0 +# for step_nb, step_output in enumerate(steps_output): +# x = self.feat_transformers[step_nb](step_output) +# res = torch.add(res, x) +# res = self.reconstruction_layer(res) +# return res + + +# class TabNetPretraining(torch.nn.Module): +# def __init__( +# self, +# input_dim, +# pretraining_ratio=0.2, +# n_d=8, +# n_a=8, +# 
n_steps=3, +# gamma=1.3, +# cat_idxs=[], +# cat_dims=[], +# cat_emb_dim=1, +# n_independent=2, +# n_shared=2, +# epsilon=1e-15, +# virtual_batch_size=128, +# momentum=0.02, +# mask_type="sparsemax", +# n_shared_decoder=1, +# n_indep_decoder=1, +# group_attention_matrix=None, +# ): +# super(TabNetPretraining, self).__init__() + +# self.cat_idxs = cat_idxs or [] +# self.cat_dims = cat_dims or [] +# self.cat_emb_dim = cat_emb_dim + +# self.input_dim = input_dim +# self.n_d = n_d +# self.n_a = n_a +# self.n_steps = n_steps +# self.gamma = gamma +# self.epsilon = epsilon +# self.n_independent = n_independent +# self.n_shared = n_shared +# self.mask_type = mask_type +# self.pretraining_ratio = pretraining_ratio +# self.n_shared_decoder = n_shared_decoder +# self.n_indep_decoder = n_indep_decoder + +# if self.n_steps <= 0: +# raise ValueError("n_steps should be a positive integer.") +# if self.n_independent == 0 and self.n_shared == 0: +# raise ValueError("n_shared and n_independent can't be both zero.") + +# self.virtual_batch_size = virtual_batch_size +# self.embedder = EmbeddingGenerator( +# input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix +# ) +# self.post_embed_dim = self.embedder.post_embed_dim + +# self.masker = RandomObfuscator( +# self.pretraining_ratio, group_matrix=self.embedder.embedding_group_matrix +# ) +# self.encoder = TabNetEncoder( +# input_dim=self.post_embed_dim, +# output_dim=self.post_embed_dim, +# n_d=n_d, +# n_a=n_a, +# n_steps=n_steps, +# gamma=gamma, +# n_independent=n_independent, +# n_shared=n_shared, +# epsilon=epsilon, +# virtual_batch_size=virtual_batch_size, +# momentum=momentum, +# mask_type=mask_type, +# group_attention_matrix=self.embedder.embedding_group_matrix, +# ) +# self.decoder = TabNetDecoder( +# self.post_embed_dim, +# n_d=n_d, +# n_steps=n_steps, +# n_independent=self.n_indep_decoder, +# n_shared=self.n_shared_decoder, +# virtual_batch_size=virtual_batch_size, +# momentum=momentum, +# ) + +# def forward(self, 
x): +# """ +# Returns: res, embedded_x, obf_vars +# res : output of reconstruction +# embedded_x : embedded input +# obf_vars : which variable where obfuscated +# """ +# embedded_x = self.embedder(x) +# if self.training: +# masked_x, obfuscated_groups, obfuscated_vars = self.masker(embedded_x) +# # set prior of encoder with obfuscated groups +# prior = 1 - obfuscated_groups +# steps_out, _ = self.encoder(masked_x, prior=prior) +# res = self.decoder(steps_out) +# return res, embedded_x, obfuscated_vars +# else: +# steps_out, _ = self.encoder(embedded_x) +# res = self.decoder(steps_out) +# return res, embedded_x, torch.ones(embedded_x.shape).to(x.device) + +# def forward_masks(self, x): +# embedded_x = self.embedder(x) +# return self.encoder.forward_masks(embedded_x) class AttentiveTransformer(torch.nn.Module): def __init__( self, - input_dim, - group_dim, - group_matrix, - virtual_batch_size=128, - momentum=0.02, - mask_type="sparsemax", - ): + input_dim: int, + group_dim: int, + virtual_batch_size: int = 128, + momentum: float = 0.02, + mask_type: str = "sparsemax", + ) -> None: """ Initialize an attention transformer. 
@@ -743,17 +630,17 @@ def __init__( ) if mask_type == "sparsemax": - # Sparsemax - self.selector = sparsemax.Sparsemax(dim=-1) + self.selector = Sparsemax() elif mask_type == "entmax": - # Entmax - self.selector = sparsemax.Entmax15(dim=-1) + self.selector = Entmax() else: raise NotImplementedError( - "Please choose either sparsemax" + "or entmax as masktype" + "Please choose either sparsemax or entmax as masktype" ) - def forward(self, priors, processed_feat): + def forward( + self, priors: torch.Tensor, processed_feat: torch.Tensor + ) -> torch.Tensor: x = self.fc(processed_feat) x = self.bn(x) x = torch.mul(x, priors) @@ -764,13 +651,13 @@ def forward(self, priors, processed_feat): class FeatTransformer(torch.nn.Module): def __init__( self, - input_dim, - output_dim, - shared_layers, - n_glu_independent, - virtual_batch_size=128, - momentum=0.02, - ): + input_dim: int, + output_dim: int, + shared_layers: torch.nn.ModuleList, + n_glu_independent: int, + virtual_batch_size: int = 128, + momentum: float = 0.02, + ) -> None: super(FeatTransformer, self).__init__() """ Initialize a feature transformer. 
@@ -819,10 +706,10 @@ def __init__( else: spec_input_dim = input_dim if is_first else output_dim self.specifics = GLU_Block( - spec_input_dim, output_dim, first=is_first, **params + spec_input_dim, output_dim, first=is_first, **params # type: ignore ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.shared(x) x = self.specifics(x) return x @@ -835,14 +722,15 @@ class GLU_Block(torch.nn.Module): def __init__( self, - input_dim, - output_dim, - n_glu=2, - first=False, - shared_layers=None, - virtual_batch_size=128, - momentum=0.02, - ): + input_dim: int, + output_dim: int, + *, + n_glu: int = 2, + first: bool = False, + shared_layers: Optional[torch.nn.ModuleList] = None, + virtual_batch_size: int = 128, + momentum: float = 0.02, + ) -> None: super(GLU_Block, self).__init__() self.first = first self.shared_layers = shared_layers @@ -852,12 +740,12 @@ def __init__( params = {"virtual_batch_size": virtual_batch_size, "momentum": momentum} fc = shared_layers[0] if shared_layers else None - self.glu_layers.append(GLU_Layer(input_dim, output_dim, fc=fc, **params)) + self.glu_layers.append(GLU_Layer(input_dim, output_dim, fc=fc, **params)) # type: ignore for glu_id in range(1, self.n_glu): fc = shared_layers[glu_id] if shared_layers else None - self.glu_layers.append(GLU_Layer(output_dim, output_dim, fc=fc, **params)) + self.glu_layers.append(GLU_Layer(output_dim, output_dim, fc=fc, **params)) # type: ignore - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: scale = torch.sqrt(torch.FloatTensor([0.5]).to(x.device)) if self.first: # the first layer of the block has no scale multiplication x = self.glu_layers[0](x) @@ -873,8 +761,14 @@ def forward(self, x): class GLU_Layer(torch.nn.Module): def __init__( - self, input_dim, output_dim, fc=None, virtual_batch_size=128, momentum=0.02 - ): + self, + input_dim: int, + output_dim: int, + *, + fc: Optional[Linear] = None, + virtual_batch_size: int = 128, + momentum: 
float = 0.02, + ) -> None: super(GLU_Layer, self).__init__() self.output_dim = output_dim @@ -888,143 +782,143 @@ def __init__( 2 * output_dim, virtual_batch_size=virtual_batch_size, momentum=momentum ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc(x) x = self.bn(x) out = torch.mul(x[:, : self.output_dim], torch.sigmoid(x[:, self.output_dim :])) return out -class EmbeddingGenerator(torch.nn.Module): - """ - Classical embeddings generator - """ - - def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dims, group_matrix): - """This is an embedding module for an entire set of features - - Parameters - ---------- - input_dim : int - Number of features coming as input (number of columns) - cat_dims : list of int - Number of modalities for each categorial features - If the list is empty, no embeddings will be done - cat_idxs : list of int - Positional index for each categorical features in inputs - cat_emb_dim : list of int - Embedding dimension for each categorical features - If int, the same embedding dimension will be used for all categorical features - group_matrix : torch matrix - Original group matrix before embeddings - """ - super(EmbeddingGenerator, self).__init__() - - if cat_dims == [] and cat_idxs == []: - self.skip_embedding = True - self.post_embed_dim = input_dim - self.embedding_group_matrix = group_matrix.to(group_matrix.device) - return - else: - self.skip_embedding = False - - self.post_embed_dim = int(input_dim + np.sum(cat_emb_dims) - len(cat_emb_dims)) - - self.embeddings = torch.nn.ModuleList() - - for cat_dim, emb_dim in zip(cat_dims, cat_emb_dims): - self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) - - # record continuous indices - self.continuous_idx = torch.ones(input_dim, dtype=torch.bool) - self.continuous_idx[cat_idxs] = 0 - - # update group matrix - n_groups = group_matrix.shape[0] - self.embedding_group_matrix = torch.empty( - (n_groups, self.post_embed_dim), 
device=group_matrix.device - ) - for group_idx in range(n_groups): - post_emb_idx = 0 - cat_feat_counter = 0 - for init_feat_idx in range(input_dim): - if self.continuous_idx[init_feat_idx] == 1: - # this means that no embedding is applied to this column - self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[ - group_idx, init_feat_idx - ] # noqa - post_emb_idx += 1 - else: - # this is a categorical feature which creates multiple embeddings - n_embeddings = cat_emb_dims[cat_feat_counter] - self.embedding_group_matrix[ - group_idx, post_emb_idx : post_emb_idx + n_embeddings - ] = ( - group_matrix[group_idx, init_feat_idx] / n_embeddings - ) # noqa - post_emb_idx += n_embeddings - cat_feat_counter += 1 - - def forward(self, x): - """ - Apply embeddings to inputs - Inputs should be (batch_size, input_dim) - Outputs will be of size (batch_size, self.post_embed_dim) - """ - if self.skip_embedding: - # no embeddings required - return x - - cols = [] - cat_feat_counter = 0 - for feat_init_idx, is_continuous in enumerate(self.continuous_idx): - # Enumerate through continuous idx boolean mask to apply embeddings - if is_continuous: - cols.append(x[:, feat_init_idx].float().view(-1, 1)) - else: - cols.append( - self.embeddings[cat_feat_counter](x[:, feat_init_idx].long()) - ) - cat_feat_counter += 1 - # concat - post_embeddings = torch.cat(cols, dim=1) - return post_embeddings - - -class RandomObfuscator(torch.nn.Module): - """ - Create and applies obfuscation masks. - The obfuscation is done at group level to match attention. 
- """ - - def __init__(self, pretraining_ratio, group_matrix): - """ - This create random obfuscation for self suppervised pretraining - Parameters - ---------- - pretraining_ratio : float - Ratio of feature to randomly discard for reconstruction - - """ - super(RandomObfuscator, self).__init__() - self.pretraining_ratio = pretraining_ratio - # group matrix is set to boolean here to pass all posssible information - self.group_matrix = (group_matrix > 0) + 0.0 - self.num_groups = group_matrix.shape[0] - - def forward(self, x): - """ - Generate random obfuscation mask. - - Returns - ------- - masked input and obfuscated variables. - """ - bs = x.shape[0] - - obfuscated_groups = torch.bernoulli( - self.pretraining_ratio * torch.ones((bs, self.num_groups), device=x.device) - ) - obfuscated_vars = torch.matmul(obfuscated_groups, self.group_matrix) - masked_input = torch.mul(1 - obfuscated_vars, x) - return masked_input, obfuscated_groups, obfuscated_vars +# class EmbeddingGenerator(torch.nn.Module): +# """ +# Classical embeddings generator +# """ + +# def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dims, group_matrix): +# """This is an embedding module for an entire set of features + +# Parameters +# ---------- +# input_dim : int +# Number of features coming as input (number of columns) +# cat_dims : list of int +# Number of modalities for each categorial features +# If the list is empty, no embeddings will be done +# cat_idxs : list of int +# Positional index for each categorical features in inputs +# cat_emb_dim : list of int +# Embedding dimension for each categorical features +# If int, the same embedding dimension will be used for all categorical features +# group_matrix : torch matrix +# Original group matrix before embeddings +# """ +# super(EmbeddingGenerator, self).__init__() + +# if cat_dims == [] and cat_idxs == []: +# self.skip_embedding = True +# self.post_embed_dim = input_dim +# self.embedding_group_matrix = group_matrix.to(group_matrix.device) 
+# return +# else: +# self.skip_embedding = False + +# self.post_embed_dim = int(input_dim + np.sum(cat_emb_dims) - len(cat_emb_dims)) + +# self.embeddings = torch.nn.ModuleList() + +# for cat_dim, emb_dim in zip(cat_dims, cat_emb_dims): +# self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) + +# # record continuous indices +# self.continuous_idx = torch.ones(input_dim, dtype=torch.bool) +# self.continuous_idx[cat_idxs] = 0 + +# # update group matrix +# n_groups = group_matrix.shape[0] +# self.embedding_group_matrix = torch.empty( +# (n_groups, self.post_embed_dim), device=group_matrix.device +# ) +# for group_idx in range(n_groups): +# post_emb_idx = 0 +# cat_feat_counter = 0 +# for init_feat_idx in range(input_dim): +# if self.continuous_idx[init_feat_idx] == 1: +# # this means that no embedding is applied to this column +# self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[ +# group_idx, init_feat_idx +# ] # noqa +# post_emb_idx += 1 +# else: +# # this is a categorical feature which creates multiple embeddings +# n_embeddings = cat_emb_dims[cat_feat_counter] +# self.embedding_group_matrix[ +# group_idx, post_emb_idx : post_emb_idx + n_embeddings +# ] = ( +# group_matrix[group_idx, init_feat_idx] / n_embeddings +# ) # noqa +# post_emb_idx += n_embeddings +# cat_feat_counter += 1 + +# def forward(self, x): +# """ +# Apply embeddings to inputs +# Inputs should be (batch_size, input_dim) +# Outputs will be of size (batch_size, self.post_embed_dim) +# """ +# if self.skip_embedding: +# # no embeddings required +# return x + +# cols = [] +# cat_feat_counter = 0 +# for feat_init_idx, is_continuous in enumerate(self.continuous_idx): +# # Enumerate through continuous idx boolean mask to apply embeddings +# if is_continuous: +# cols.append(x[:, feat_init_idx].float().view(-1, 1)) +# else: +# cols.append( +# self.embeddings[cat_feat_counter](x[:, feat_init_idx].long()) +# ) +# cat_feat_counter += 1 +# # concat +# post_embeddings = torch.cat(cols, 
dim=1) +# return post_embeddings + + +# class RandomObfuscator(torch.nn.Module): +# """ +# Create and applies obfuscation masks. +# The obfuscation is done at group level to match attention. +# """ + +# def __init__(self, pretraining_ratio, group_matrix): +# """ +# This create random obfuscation for self suppervised pretraining +# Parameters +# ---------- +# pretraining_ratio : float +# Ratio of feature to randomly discard for reconstruction + +# """ +# super(RandomObfuscator, self).__init__() +# self.pretraining_ratio = pretraining_ratio +# # group matrix is set to boolean here to pass all posssible information +# self.group_matrix = (group_matrix > 0) + 0.0 +# self.num_groups = group_matrix.shape[0] + +# def forward(self, x): +# """ +# Generate random obfuscation mask. + +# Returns +# ------- +# masked input and obfuscated variables. +# """ +# bs = x.shape[0] + +# obfuscated_groups = torch.bernoulli( +# self.pretraining_ratio * torch.ones((bs, self.num_groups), device=x.device) +# ) +# obfuscated_vars = torch.matmul(obfuscated_groups, self.group_matrix) +# masked_input = torch.mul(1 - obfuscated_vars, x) +# return masked_input, obfuscated_groups, obfuscated_vars diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 0332fa03..20a293b4 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -40,7 +40,7 @@ def __init__( print_interval: int = 100, # model params model_type: str = "mlp", - mlp_params: Optional[dict] = None, + model_params: Optional[dict] = None, dim_embed: int = 128, # early stopping n_iter_min: int = 100, @@ -95,13 +95,6 @@ def fit( cat_counts = [0] self.feature_names_out = self.feature_names - model_params = dict( - num_classes=self.n_classes, - conditional=cond is not None, - mlp_params=self.mlp_params, - dim_emb=self.dim_embed, - ) - dataset = TensorDataset( torch.tensor(X.values, 
dtype=torch.float32, device=self.device), torch.tensor([torch.nan] * len(X), dtype=torch.float32, device=self.device) @@ -117,11 +110,14 @@ def fit( self.diffusion = GaussianMultinomialDiffusion( model_type=self.model_type, - model_params=model_params, + model_params=self.model_params, num_categorical_features=cat_counts, num_numerical_features=X.shape[1] - len(cat_cols), gaussian_loss_type=self.gaussian_loss_type, num_timesteps=self.num_timesteps, + num_classes=self.n_classes, + conditional=cond is not None, + dim_emb=self.dim_embed, scheduler=self.scheduler, device=self.device, ).to(self.device) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 2e29aec3..6c11c5be 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -18,7 +18,7 @@ from synthcity.logger import debug, info, warning # synthcity relative -from .modules import MLPDiffusion, ResNetDiffusion +from .modules import DiffusionModel from .utils import ( discretized_gaussian_log_likelihood, index_to_log_onehot, @@ -67,11 +67,15 @@ def alpha_bar(t: float) -> float: class GaussianMultinomialDiffusion(torch.nn.Module): def __init__( self, + *, num_numerical_features: int, num_categorical_features: tuple, - model_type: str = "mlp", - model_params: Optional[dict] = None, + model_type: str, + model_params: dict, num_timesteps: int = 1000, + num_classes: int = 0, + conditional: bool = False, + dim_emb: int = 128, gaussian_loss_type: str = "mse", gaussian_parametrization: str = "eps", multinomial_loss_type: str = "vb_stochastic", @@ -110,24 +114,14 @@ def __init__( self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i])) self.offsets = torch.from_numpy(np.append([0], offsets)).to(device).long() - if model_params is None: - model_params = dict( - 
dim_in=self.dim_input, num_classes=0, conditional=False, mlp_params=None - ) - else: - model_params["dim_in"] = self.dim_input - - if model_params["mlp_params"] is None: - model_params["mlp_params"] = dict( - n_units_hidden=256, n_layers_hidden=3, dropout=0.0 - ) - - if model_type == "mlp": - self.denoise_fn = MLPDiffusion(**model_params) - elif model_type == "resnet": - self.denoise_fn = ResNetDiffusion(**model_params) - else: - raise NotImplementedError(f"unknown model type: {model_type}") + self.denoise_fn = DiffusionModel( + dim_in=self.dim_input, + dim_emb=dim_emb, + num_classes=num_classes, + conditional=conditional, + model_type=model_type, + model_params=model_params, + ) self.gaussian_loss_type = gaussian_loss_type self.gaussian_parametrization = gaussian_parametrization diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py index 8d3a777b..5e49f76e 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/modules.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/modules.py @@ -11,7 +11,7 @@ from torch import Tensor, nn # synthcity absolute -from synthcity.plugins.core.models.mlp import MLP, get_nonlin +from synthcity.plugins.core.models.factory import get_model, get_nonlin class TimeStepEmbedding(nn.Module): @@ -58,15 +58,14 @@ def forward(self, timesteps: Tensor) -> Tensor: return self.fc(emb) -class MLPDiffusion(nn.Module): - add_residual = False - +class DiffusionModel(nn.Module): def __init__( self, dim_in: int, dim_emb: int = 128, *, - mlp_params: dict = {}, + model_type: str = "mlp", + model_params: dict = {}, conditional: bool = False, num_classes: int = 0, emb_nonlin: Union[str, nn.Module] = "silu", @@ -91,13 +90,17 @@ def __init__( elif self.num_classes == 0: # regression self.label_emb = nn.Linear(1, dim_emb) - self.model = MLP( - n_units_in=dim_emb, - n_units_out=dim_in, - task_type="/", - residual=self.add_residual, - **mlp_params, - ) + if not model_params: + 
model_params = {} # avoid changing the default dict + + if model_type == "mlp": + if not model_params: + model_params = dict(n_units_hidden=256, n_layers_hidden=3, dropout=0.0) + model_params.update(n_units_in=dim_emb, n_units_out=dim_in) + elif model_type == "tabnet": + model_params.update(input_dim=dim_emb, output_dim=dim_in) + + self.model = get_model(model_type, model_params) def forward(self, x: Tensor, t: Tensor, y: Optional[Tensor] = None) -> Tensor: emb = self.time_emb(t) @@ -111,7 +114,3 @@ def forward(self, x: Tensor, t: Tensor, y: Optional[Tensor] = None) -> Tensor: emb += self.emb_nonlin(self.label_emb(y)) x = self.proj(x) + emb return self.model(x) - - -class ResNetDiffusion(MLPDiffusion): - add_residual = True diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 74b72142..ad07a06c 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -17,7 +17,7 @@ from synthcity.utils.serialization import dataframe_hash # synthcity relative -from .data_encoder import get_encoder +from .factory import get_feature_encoder class FeatureInfo(BaseModel): @@ -114,9 +114,13 @@ def _fit_feature(self, feature: pd.Series, feature_type: str) -> FeatureInfo: Information of the fitted feature encoder. 
""" if feature_type == "discrete": - encoder = get_encoder(self.categorical_encoder)(**self.cat_encoder_params) + encoder = get_feature_encoder( + self.categorical_encoder, self.cat_encoder_params + ) else: - encoder = get_encoder(self.continuous_encoder)(**self.cont_encoder_params) + encoder = get_feature_encoder( + self.continuous_encoder, self.cont_encoder_params + ) encoder.fit(feature) diff --git a/src/synthcity/plugins/core/models/ts_model.py b/src/synthcity/plugins/core/models/ts_model.py index a7026146..34f6ac4c 100644 --- a/src/synthcity/plugins/core/models/ts_model.py +++ b/src/synthcity/plugins/core/models/ts_model.py @@ -20,7 +20,9 @@ # synthcity absolute import synthcity.logger as log -from synthcity.plugins.core.models.mlp import MLP, MultiActivationHead, get_nonlin +from synthcity.plugins.core.models.factory import get_nonlin +from synthcity.plugins.core.models.layers import MultiActivationHead +from synthcity.plugins.core.models.mlp import MLP from synthcity.utils.constants import DEVICE from synthcity.utils.reproducibility import enable_reproducible_results from synthcity.utils.samplers import ImbalancedDatasetSampler diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 8eda1ea9..88df67f2 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -48,14 +48,16 @@ class TabDDPMPlugin(Plugin): L2 weight decay. batch_size: int = 1024 Size of mini-batches. - model_type: str = "mlp" - Type of model to use. Either "mlp" or "resnet". num_timesteps: int = 1000 Number of timesteps to use in the diffusion process. gaussian_loss_type: str = "mse" Type of loss to use for the Gaussian diffusion process. Either "mse" or "kl". scheduler: str = "cosine" The scheduler of forward process variance 'beta' to use. Either "cosine" or "linear". + model_type: str = "mlp" + Type of diffusion model to use ("mlp", "resnet", or "tabnet"). 
+ model_params: dict = dict(n_layers_hidden=3, n_units_hidden=256, dropout=0.0) + Parameters of the diffusion model. Should be different for different model types. device: Any = DEVICE Device to use for training. callbacks: Sequence[Callback] = () @@ -101,7 +103,6 @@ def __init__( lr: float = 0.002, weight_decay: float = 1e-4, batch_size: int = 1024, - model_type: str = "mlp", num_timesteps: int = 1000, gaussian_loss_type: str = "mse", scheduler: str = "cosine", @@ -109,11 +110,11 @@ def __init__( callbacks: Sequence[Callback] = (), log_interval: int = 100, print_interval: int = 500, - # model params - n_layers_hidden: int = 3, - dim_hidden: int = 256, - dropout: float = 0.0, + model_type: str = "mlp", + model_params: dict = {}, dim_embed: int = 128, + continuous_encoder: str = "quantile", + cont_encoder_params: dict = {}, # core plugin arguments random_state: int = 0, workspace: Path = Path("workspace"), @@ -132,10 +133,6 @@ def __init__( self.is_classification = is_classification - mlp_params = dict( - n_layers_hidden=n_layers_hidden, n_units_hidden=dim_hidden, dropout=dropout - ) - self.model = TabDDPM( n_iter=n_iter, lr=lr, @@ -150,14 +147,17 @@ def __init__( log_interval=log_interval, print_interval=print_interval, model_type=model_type, - mlp_params=mlp_params, + model_params=model_params.copy(), dim_embed=dim_embed, ) + cont_encoder_params = cont_encoder_params.copy() + cont_encoder_params.update(random_state=random_state) + self.encoder = TabularEncoder( - continuous_encoder="quantile", - categorical_encoder="passthrough", - cont_encoder_params=dict(random_state=random_state), + continuous_encoder=continuous_encoder, + cont_encoder_params=cont_encoder_params, + categorical_encoder="none", cat_encoder_params=dict(), ) diff --git a/tests/plugins/core/models/test_mlp.py b/tests/plugins/core/models/test_mlp.py index e70b76de..ac9a2db3 100644 --- a/tests/plugins/core/models/test_mlp.py +++ b/tests/plugins/core/models/test_mlp.py @@ -5,12 +5,8 @@ from 
sklearn.datasets import load_diabetes, load_digits # synthcity absolute -from synthcity.plugins.core.models.mlp import ( - MLP, - LinearLayer, - MultiActivationHead, - ResidualLayer, -) +from synthcity.plugins.core.models.layers import MultiActivationHead +from synthcity.plugins.core.models.mlp import MLP, LinearLayer, ResidualLayer def test_network_config() -> None: diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index 2f9afeae..d7f93da3 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -19,10 +19,12 @@ plugin_name = "ddpm" plugin_params = dict( n_iter=1000, - batch_size=200, - num_timesteps=500, + batch_size=1000, + num_timesteps=100, log_interval=10, print_interval=100, + model_type="tabnet", + # model_params=dict() ) From 6e58cf3b97fd79e303e72834793dfd553fc2a515 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 2 Apr 2023 23:20:38 +0200 Subject: [PATCH 38/95] update docstrings and refactor --- src/synthcity/plugins/core/models/tabnet.py | 491 ++++++++----------- src/synthcity/plugins/generic/plugin_ddpm.py | 18 +- tests/plugins/generic/test_ddpm.py | 11 +- 3 files changed, 210 insertions(+), 310 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabnet.py b/src/synthcity/plugins/core/models/tabnet.py index 5a4f3051..c7e41c52 100644 --- a/src/synthcity/plugins/core/models/tabnet.py +++ b/src/synthcity/plugins/core/models/tabnet.py @@ -9,98 +9,19 @@ # synthcity relative from .layers import Entmax, Sparsemax -# class TabNet(torch.nn.Module): -# def __init__( -# self, -# input_dim, -# output_dim, -# n_d=8, -# n_a=8, -# n_steps=3, -# gamma=1.3, -# n_independent=2, -# n_shared=2, -# epsilon=1e-15, -# virtual_batch_size=128, -# momentum=0.02, -# mask_type="sparsemax", -# group_attention_matrix=None, -# ): -# """ -# Defines TabNet network -# Parameters -# ---------- -# input_dim : int -# Initial number of features -# output_dim : int -# Dimension of network 
output -# examples : one for regression, 2 for binary classification etc... -# n_d : int -# Dimension of the prediction layer (usually between 4 and 64) -# n_a : int -# Dimension of the attention layer (usually between 4 and 64) -# n_steps : int -# Number of successive steps in the network (usually between 3 and 10) -# gamma : float -# Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) -# n_independent : int -# Number of independent GLU layer in each GLU block (default 2) -# n_shared : int -# Number of independent GLU layer in each GLU block (default 2) -# epsilon : float -# Avoid log(0), this should be kept very low -# virtual_batch_size : int -# Batch size for Ghost Batch Normalization -# momentum : float -# Float value between 0 and 1 which will be used for momentum in all batch norm -# mask_type : str -# Either "sparsemax" or "entmax" : this is the masking function to use -# group_attention_matrix : torch matrix -# Matrix of size (n_groups, input_dim), m_ij = importance within group i of feature j -# """ -# super(TabNet, self).__init__() - -# if group_attention_matrix is None: -# group_attention_matrix = torch.Tensor([]) - -# self.input_dim = input_dim -# self.output_dim = output_dim -# self.n_d = n_d -# self.n_a = n_a -# self.n_steps = n_steps -# self.gamma = gamma -# self.epsilon = epsilon -# self.n_independent = n_independent -# self.n_shared = n_shared -# self.mask_type = mask_type - -# self.virtual_batch_size = virtual_batch_size -# self.post_embed_dim = self.embedder.post_embed_dim - -# self.tabnet = TabNetNoEmbeddings( -# self.post_embed_dim, -# output_dim, -# n_d, -# n_a, -# n_steps, -# gamma, -# n_independent, -# n_shared, -# epsilon, -# virtual_batch_size, -# momentum, -# mask_type, -# self.embedder.embedding_group_matrix, -# ) +def initialize_non_glu(module: Linear, input_dim: int, output_dim: int) -> None: + gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(4 * input_dim)) + 
torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + # torch.nn.init.zeros_(module.bias) + return -# def forward(self, x): -# x = self.embedder(x) -# return self.tabnet(x) -# def forward_masks(self, x): -# x = self.embedder(x) -# return self.tabnet.forward_masks(x) +def initialize_glu(module: Linear, input_dim: int, output_dim: int) -> None: + gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) + torch.nn.init.xavier_normal_(module.weight, gain=gain_value) + # torch.nn.init.zeros_(module.bias) + return class TabNet(torch.nn.Module): @@ -121,7 +42,7 @@ def __init__( group_attention_matrix: Optional[torch.Tensor] = None, ) -> None: """ - Defines main part of the TabNet network without the embedding layers. + Defines main part of the TabNet network. Parameters ---------- @@ -131,9 +52,9 @@ def __init__( Dimension of network output examples : one for regression, 2 for binary classification etc... n_d : int - Dimension of the prediction layer (usually between 4 and 64) + Dimension of the prediction layer (usually between 4 and 64) n_a : int - Dimension of the attention layer (usually between 4 and 64) + Dimension of the attention layer (usually between 4 and 64) n_steps : int Number of successive steps in the network (usually between 3 and 10) gamma : float @@ -202,20 +123,6 @@ def forward_masks(self, x: torch.Tensor) -> torch.Tensor: return self.encoder.forward_masks(x) -def initialize_non_glu(module: Linear, input_dim: int, output_dim: int) -> None: - gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(4 * input_dim)) - torch.nn.init.xavier_normal_(module.weight, gain=gain_value) - # torch.nn.init.zeros_(module.bias) - return - - -def initialize_glu(module: Linear, input_dim: int, output_dim: int) -> None: - gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) - torch.nn.init.xavier_normal_(module.weight, gain=gain_value) - # torch.nn.init.zeros_(module.bias) - return - - class GBN(torch.nn.Module): """ Ghost Batch 
Normalization @@ -412,191 +319,6 @@ def forward_masks(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tenso return M_explain, masks -# class TabNetDecoder(torch.nn.Module): -# def __init__( -# self, -# input_dim, -# n_d=8, -# n_steps=3, -# n_independent=1, -# n_shared=1, -# virtual_batch_size=128, -# momentum=0.02, -# ): -# """ -# Defines main part of the TabNet network without the embedding layers. - -# Parameters -# ---------- -# input_dim : int -# Number of features -# output_dim : int or list of int for multi task classification -# Dimension of network output -# examples : one for regression, 2 for binary classification etc... -# n_d : int -# Dimension of the prediction layer (usually between 4 and 64) -# n_steps : int -# Number of successive steps in the network (usually between 3 and 10) -# gamma : float -# Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) -# n_independent : int -# Number of independent GLU layer in each GLU block (default 1) -# n_shared : int -# Number of independent GLU layer in each GLU block (default 1) -# virtual_batch_size : int -# Batch size for Ghost Batch Normalization -# momentum : float -# Float value between 0 and 1 which will be used for momentum in all batch norm -# """ -# super(TabNetDecoder, self).__init__() -# self.input_dim = input_dim -# self.n_d = n_d -# self.n_steps = n_steps -# self.n_independent = n_independent -# self.n_shared = n_shared -# self.virtual_batch_size = virtual_batch_size - -# self.feat_transformers = torch.nn.ModuleList() - -# if self.n_shared > 0: -# shared_feat_transform = torch.nn.ModuleList() -# for i in range(self.n_shared): -# if i == 0: -# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) -# else: -# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) - -# else: -# shared_feat_transform = None - -# for step in range(n_steps): -# transformer = FeatTransformer( -# n_d, -# n_d, -# shared_feat_transform, -# 
n_glu_independent=self.n_independent, -# virtual_batch_size=self.virtual_batch_size, -# momentum=momentum, -# ) -# self.feat_transformers.append(transformer) - -# self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False) -# initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim) - -# def forward(self, steps_output): -# res = 0 -# for step_nb, step_output in enumerate(steps_output): -# x = self.feat_transformers[step_nb](step_output) -# res = torch.add(res, x) -# res = self.reconstruction_layer(res) -# return res - - -# class TabNetPretraining(torch.nn.Module): -# def __init__( -# self, -# input_dim, -# pretraining_ratio=0.2, -# n_d=8, -# n_a=8, -# n_steps=3, -# gamma=1.3, -# cat_idxs=[], -# cat_dims=[], -# cat_emb_dim=1, -# n_independent=2, -# n_shared=2, -# epsilon=1e-15, -# virtual_batch_size=128, -# momentum=0.02, -# mask_type="sparsemax", -# n_shared_decoder=1, -# n_indep_decoder=1, -# group_attention_matrix=None, -# ): -# super(TabNetPretraining, self).__init__() - -# self.cat_idxs = cat_idxs or [] -# self.cat_dims = cat_dims or [] -# self.cat_emb_dim = cat_emb_dim - -# self.input_dim = input_dim -# self.n_d = n_d -# self.n_a = n_a -# self.n_steps = n_steps -# self.gamma = gamma -# self.epsilon = epsilon -# self.n_independent = n_independent -# self.n_shared = n_shared -# self.mask_type = mask_type -# self.pretraining_ratio = pretraining_ratio -# self.n_shared_decoder = n_shared_decoder -# self.n_indep_decoder = n_indep_decoder - -# if self.n_steps <= 0: -# raise ValueError("n_steps should be a positive integer.") -# if self.n_independent == 0 and self.n_shared == 0: -# raise ValueError("n_shared and n_independent can't be both zero.") - -# self.virtual_batch_size = virtual_batch_size -# self.embedder = EmbeddingGenerator( -# input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix -# ) -# self.post_embed_dim = self.embedder.post_embed_dim - -# self.masker = RandomObfuscator( -# self.pretraining_ratio, 
group_matrix=self.embedder.embedding_group_matrix -# ) -# self.encoder = TabNetEncoder( -# input_dim=self.post_embed_dim, -# output_dim=self.post_embed_dim, -# n_d=n_d, -# n_a=n_a, -# n_steps=n_steps, -# gamma=gamma, -# n_independent=n_independent, -# n_shared=n_shared, -# epsilon=epsilon, -# virtual_batch_size=virtual_batch_size, -# momentum=momentum, -# mask_type=mask_type, -# group_attention_matrix=self.embedder.embedding_group_matrix, -# ) -# self.decoder = TabNetDecoder( -# self.post_embed_dim, -# n_d=n_d, -# n_steps=n_steps, -# n_independent=self.n_indep_decoder, -# n_shared=self.n_shared_decoder, -# virtual_batch_size=virtual_batch_size, -# momentum=momentum, -# ) - -# def forward(self, x): -# """ -# Returns: res, embedded_x, obf_vars -# res : output of reconstruction -# embedded_x : embedded input -# obf_vars : which variable where obfuscated -# """ -# embedded_x = self.embedder(x) -# if self.training: -# masked_x, obfuscated_groups, obfuscated_vars = self.masker(embedded_x) -# # set prior of encoder with obfuscated groups -# prior = 1 - obfuscated_groups -# steps_out, _ = self.encoder(masked_x, prior=prior) -# res = self.decoder(steps_out) -# return res, embedded_x, obfuscated_vars -# else: -# steps_out, _ = self.encoder(embedded_x) -# res = self.decoder(steps_out) -# return res, embedded_x, torch.ones(embedded_x.shape).to(x.device) - -# def forward_masks(self, x): -# embedded_x = self.embedder(x) -# return self.encoder.forward_masks(embedded_x) - - class AttentiveTransformer(torch.nn.Module): def __init__( self, @@ -789,9 +511,194 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out +# class TabNetDecoder(torch.nn.Module): +# def __init__( +# self, +# input_dim, +# n_d=8, +# n_steps=3, +# n_independent=1, +# n_shared=1, +# virtual_batch_size=128, +# momentum=0.02, +# ): +# """ +# Defines main part of the TabNet network without the embedding layers. 
+ +# Parameters +# ---------- +# input_dim : int +# Number of features +# output_dim : int or list of int for multi task classification +# Dimension of network output +# examples : one for regression, 2 for binary classification etc... +# n_d : int +# Dimension of the prediction layer (usually between 4 and 64) +# n_steps : int +# Number of successive steps in the network (usually between 3 and 10) +# gamma : float +# Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) +# n_independent : int +# Number of independent GLU layer in each GLU block (default 1) +# n_shared : int +# Number of independent GLU layer in each GLU block (default 1) +# virtual_batch_size : int +# Batch size for Ghost Batch Normalization +# momentum : float +# Float value between 0 and 1 which will be used for momentum in all batch norm +# """ +# super(TabNetDecoder, self).__init__() +# self.input_dim = input_dim +# self.n_d = n_d +# self.n_steps = n_steps +# self.n_independent = n_independent +# self.n_shared = n_shared +# self.virtual_batch_size = virtual_batch_size + +# self.feat_transformers = torch.nn.ModuleList() + +# if self.n_shared > 0: +# shared_feat_transform = torch.nn.ModuleList() +# for i in range(self.n_shared): +# if i == 0: +# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) +# else: +# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) + +# else: +# shared_feat_transform = None + +# for step in range(n_steps): +# transformer = FeatTransformer( +# n_d, +# n_d, +# shared_feat_transform, +# n_glu_independent=self.n_independent, +# virtual_batch_size=self.virtual_batch_size, +# momentum=momentum, +# ) +# self.feat_transformers.append(transformer) + +# self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False) +# initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim) + +# def forward(self, steps_output): +# res = 0 +# for step_nb, step_output in enumerate(steps_output): +# x = 
self.feat_transformers[step_nb](step_output) +# res = torch.add(res, x) +# res = self.reconstruction_layer(res) +# return res + + +# class TabNetPretraining(torch.nn.Module): +# def __init__( +# self, +# input_dim, +# pretraining_ratio=0.2, +# n_d=8, +# n_a=8, +# n_steps=3, +# gamma=1.3, +# cat_idxs=[], +# cat_dims=[], +# cat_emb_dim=1, +# n_independent=2, +# n_shared=2, +# epsilon=1e-15, +# virtual_batch_size=128, +# momentum=0.02, +# mask_type="sparsemax", +# n_shared_decoder=1, +# n_indep_decoder=1, +# group_attention_matrix=None, +# ): +# super(TabNetPretraining, self).__init__() + +# self.cat_idxs = cat_idxs or [] +# self.cat_dims = cat_dims or [] +# self.cat_emb_dim = cat_emb_dim + +# self.input_dim = input_dim +# self.n_d = n_d +# self.n_a = n_a +# self.n_steps = n_steps +# self.gamma = gamma +# self.epsilon = epsilon +# self.n_independent = n_independent +# self.n_shared = n_shared +# self.mask_type = mask_type +# self.pretraining_ratio = pretraining_ratio +# self.n_shared_decoder = n_shared_decoder +# self.n_indep_decoder = n_indep_decoder + +# if self.n_steps <= 0: +# raise ValueError("n_steps should be a positive integer.") +# if self.n_independent == 0 and self.n_shared == 0: +# raise ValueError("n_shared and n_independent can't be both zero.") + +# self.virtual_batch_size = virtual_batch_size +# self.embedder = EmbeddingGenerator( +# input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix +# ) +# self.post_embed_dim = self.embedder.post_embed_dim + +# self.masker = RandomObfuscator( +# self.pretraining_ratio, group_matrix=self.embedder.embedding_group_matrix +# ) +# self.encoder = TabNetEncoder( +# input_dim=self.post_embed_dim, +# output_dim=self.post_embed_dim, +# n_d=n_d, +# n_a=n_a, +# n_steps=n_steps, +# gamma=gamma, +# n_independent=n_independent, +# n_shared=n_shared, +# epsilon=epsilon, +# virtual_batch_size=virtual_batch_size, +# momentum=momentum, +# mask_type=mask_type, +# 
group_attention_matrix=self.embedder.embedding_group_matrix, +# ) +# self.decoder = TabNetDecoder( +# self.post_embed_dim, +# n_d=n_d, +# n_steps=n_steps, +# n_independent=self.n_indep_decoder, +# n_shared=self.n_shared_decoder, +# virtual_batch_size=virtual_batch_size, +# momentum=momentum, +# ) + +# def forward(self, x): +# """ +# Returns: res, embedded_x, obf_vars +# res : output of reconstruction +# embedded_x : embedded input +# obf_vars : which variable where obfuscated +# """ +# embedded_x = self.embedder(x) +# if self.training: +# masked_x, obfuscated_groups, obfuscated_vars = self.masker(embedded_x) +# # set prior of encoder with obfuscated groups +# prior = 1 - obfuscated_groups +# steps_out, _ = self.encoder(masked_x, prior=prior) +# res = self.decoder(steps_out) +# return res, embedded_x, obfuscated_vars +# else: +# steps_out, _ = self.encoder(embedded_x) +# res = self.decoder(steps_out) +# return res, embedded_x, torch.ones(embedded_x.shape).to(x.device) + +# def forward_masks(self, x): +# embedded_x = self.embedder(x) +# return self.encoder.forward_masks(embedded_x) + + # class EmbeddingGenerator(torch.nn.Module): # """ -# Classical embeddings generator +# Categorical embeddings generator # """ # def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dims, group_matrix): diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 88df67f2..09826a97 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -66,12 +66,6 @@ class TabDDPMPlugin(Plugin): Number of iterations between logging. print_interval: int = 500 Number of iterations between printing. - n_layers_hidden: int = 3 - Number of hidden layers in the MLP. - dim_hidden: int = 256 - Number of hidden units per hidden layer in the MLP. - dropout: float = 0.0 - Dropout rate. dim_embed: int = 128 Dimensionality of the embedding space. 
random_state: int @@ -191,8 +185,8 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: LogIntDistribution(name="batch_size", low=256, high=4096), IntegerDistribution(name="num_timesteps", low=10, high=1000), LogIntDistribution(name="n_iter", low=1000, high=10000), - IntegerDistribution(name="n_layers_hidden", low=2, high=8), - LogIntDistribution(name="dim_hidden", low=128, high=1024), + # IntegerDistribution(name="n_layers_hidden", low=2, high=8), + # LogIntDistribution(name="dim_hidden", low=128, high=1024), ] def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": @@ -205,8 +199,6 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": If the task is regression, the target variable is not specially treated. There is no condition by default, but can be given by the user, either as a column name or an array-like. """ df = X.dataframe() - self.feature_names = df.columns - cond = kwargs.pop("cond", None) self.loss_history = None @@ -228,10 +220,8 @@ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": if cond is not None: if type(cond) is str: cond = df[cond] - self.expecting_conditional = True - - if cond is not None: cond = pd.Series(cond, index=df.index) + self.expecting_conditional = True # NOTE: cond may also be included in the dataframe self.model.fit(df, cond, **kwargs) @@ -254,7 +244,7 @@ def callback(count): # type: ignore df = self.encoder.inverse_transform(df) if self.is_classification: df = df.join(pd.Series(cond, name=self.target_name)) - return df[self.feature_names] # reorder columns + return df return self._safe_generate(callback, count, syn_schema, **kwargs) diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index d7f93da3..1e8766ac 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -21,11 +21,14 @@ n_iter=1000, batch_size=1000, num_timesteps=100, - log_interval=10, - print_interval=100, - 
model_type="tabnet", - # model_params=dict() + model_type="mlp", ) +# plugin_params = dict( +# n_iter=1000, +# batch_size=1000, +# num_timesteps=30, +# model_type="tabnet", +# ) def extend_fixtures( From 2a6ca6fde60cd1cc8c64746b31815dfa11de1a80 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 2 Apr 2023 23:44:57 +0200 Subject: [PATCH 39/95] fix type annotation compatibility --- src/synthcity/plugins/core/models/functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/synthcity/plugins/core/models/functions.py b/src/synthcity/plugins/core/models/functions.py index 27801034..51839e1a 100644 --- a/src/synthcity/plugins/core/models/functions.py +++ b/src/synthcity/plugins/core/models/functions.py @@ -2,7 +2,7 @@ Custom differentiable tensor functions. """ # stdlib -from typing import Any +from typing import Any, Tuple # third party import torch @@ -56,7 +56,7 @@ def forward( return output @staticmethod - def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: supp_size, output = ctx.saved_tensors dim = ctx.dim grad_input = grad_output.clone() @@ -70,7 +70,7 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: @staticmethod def _threshold_and_support( input: torch.Tensor, dim: int = -1 - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Sparsemax building block: compute the threshold Parameters @@ -132,7 +132,7 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: @staticmethod def _threshold_and_support( input: torch.Tensor, dim: int = -1 - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: Xsrt, _ = torch.sort(input, descending=True, dim=dim) rho = _make_ix_like(input, dim) From 36acaa04853d6c94cd986809da9adcbff27d3c41 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 3 Apr 2023 
13:45:22 +0200 Subject: [PATCH 40/95] make SkipConnection serializable --- src/synthcity/plugins/core/models/layers.py | 47 +++++++++------------ 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/src/synthcity/plugins/core/models/layers.py b/src/synthcity/plugins/core/models/layers.py index 9be97abd..54ccf5d0 100644 --- a/src/synthcity/plugins/core/models/layers.py +++ b/src/synthcity/plugins/core/models/layers.py @@ -44,41 +44,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.transpose(*self.dims) +@validate_arguments(config=dict(arbitrary_types_allowed=True)) +def _forward_skip_connection( + self: nn.Module, X: torch.Tensor, *args: Any, **kwargs: Any +) -> torch.Tensor: + # if X.shape[-1] == 0: + # return torch.zeros((*X.shape[:-1], self.n_units_out)).to(self.device) + X = X.float().to(self.device) + out = self._forward(X, *args, **kwargs) + return torch.cat([out, X], dim=-1) + + def SkipConnection(cls: Type[nn.Module]) -> Type[nn.Module]: """Wraps a model to add a skip connection from the input to the output. 
Example: >>> ResidualBlock = SkipConnection(MLP) - >>> ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64) - SkipConnection(MLP)( - (model): Sequential( - (0): LinearLayer( - (model): Sequential( - (0): Linear(in_features=10, out_features=64, bias=True) - (1): ReLU() - ) - ) - (1): Linear(in_features=64, out_features=3, bias=True) - ) - (loss): MSELoss() - ) + >>> res_block = ResidualBlock(n_units_in=10, n_units_out=3, n_units_hidden=64) + >>> res_block(torch.ones(10, 10)).shape + (10, 13) """ - class WrappedModule(cls): # type: ignore + class Wrapper(cls): # type: ignore device: torch.device = DEVICE - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def forward(self, X: torch.Tensor) -> torch.Tensor: - # if X.shape[-1] == 0: - # return torch.zeros((*X.shape[:-1], self.n_units_out)).to(self.device) - X = X.float().to(self.device) - out = super().forward(X) - return torch.cat([out, X], dim=-1) - - WrappedModule.__name__ = f"SkipConnection({cls.__name__})" - WrappedModule.__qualname__ = f"SkipConnection({cls.__qualname__})" - WrappedModule.__doc__ = f"""(With skipped connection) {cls.__doc__}""" - return WrappedModule + Wrapper._forward = cls.forward + Wrapper.forward = _forward_skip_connection + Wrapper.__name__ = f"SkipConnection({cls.__name__})" + Wrapper.__qualname__ = f"SkipConnection({cls.__qualname__})" + Wrapper.__doc__ = f"""(With skipped connection) {cls.__doc__}""" + return Wrapper # class GLU(nn.Module): From de15b9bcd4dbf2d29a4a90280caf695cb4108138 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 3 Apr 2023 18:29:19 +0200 Subject: [PATCH 41/95] fix TabularEncoder.activation_layout --- .gitignore | 2 +- .../plugins/core/models/feature_encoder.py | 7 +++---- .../plugins/core/models/tabular_encoder.py | 14 ++++++++++---- .../core/models/test_tabular_encoder.py | 18 +++++++----------- tests/plugins/generic/test_ddpm.py | 2 +- 5 files changed, 22 insertions(+), 21 deletions(-) diff --git a/.gitignore 
b/.gitignore index 5195f6c2..20037c5f 100644 --- a/.gitignore +++ b/.gitignore @@ -67,4 +67,4 @@ lightning_logs generated MNIST cifar-10* -local_test.py +local_test*.py diff --git a/src/synthcity/plugins/core/models/feature_encoder.py b/src/synthcity/plugins/core/models/feature_encoder.py index 995a65a6..70807e31 100644 --- a/src/synthcity/plugins/core/models/feature_encoder.py +++ b/src/synthcity/plugins/core/models/feature_encoder.py @@ -172,6 +172,8 @@ def _inverse_transform(self, data: np.ndarray) -> np.ndarray: class BayesianGMMEncoder(FeatureEncoder): """Bayesian Gaussian Mixture encoder""" + n_dim_in = 2 + def __init__( self, n_components: int = 10, @@ -181,7 +183,6 @@ def __init__( std_multiplier: int = 4, ) -> None: self.n_components = n_components - self.random_state = random_state self.weight_threshold = weight_threshold self.clip_output = clip_output self.std_multiplier = std_multiplier @@ -190,13 +191,12 @@ def __init__( random_state=random_state, weight_concentration_prior=1e-3, ) - self.weights: List[float] def _fit(self, x: np.ndarray, **kwargs: Any) -> "BayesianGaussianMixture": self.min_value = x.min() self.max_value = x.max() - self.model.fit(x.reshape(-1, 1)) + self.model.fit(x) self.weights = self.model.weights_ self.means = self.model.means_.reshape(-1) self.stds = np.sqrt(self.model.covariances_).reshape(-1) @@ -204,7 +204,6 @@ def _fit(self, x: np.ndarray, **kwargs: Any) -> "BayesianGaussianMixture": return self def _transform(self, x: np.ndarray) -> np.ndarray: - x = x.reshape(-1, 1) means = self.means.reshape(1, -1) stds = self.stds.reshape(1, -1) diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index ad07a06c..364a6b57 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -274,10 +274,18 @@ def activation_layout( - discrete, and with length , the length of the one-hot encoding. 
""" out = [] + acts = dict(discrete=discrete_activation, continuous=continuous_activation) + # NOTE: be careful with the dim of softmax! for column_transform_info in self._column_transform_info_list: + ct = column_transform_info.trans_feature_types[0] + d = 0 for t in column_transform_info.trans_feature_types: - act = discrete_activation if t == "discrete" else continuous_activation - out.append((act, 1)) + if t != ct: + out.append((acts[ct], d)) + ct = t + d = 0 + d += 1 + out.append((acts[ct], d)) return out @@ -291,10 +299,8 @@ class BinEncoder(TabularEncoder): continuous_encoder = "bayesian_gmm" cont_encoder_params = dict(n_components=2) categorical_encoder = "passthrough" # "onehot" - # ! onehot encoder does not pass the tests cat_encoder_params = dict() # dict(handle_unknown="ignore", sparse=False) - # TODO: check if this is correct def _transform_feature( self, column_transform_info: FeatureInfo, feature: pd.Series ) -> pd.DataFrame: diff --git a/tests/plugins/core/models/test_tabular_encoder.py b/tests/plugins/core/models/test_tabular_encoder.py index 6837190a..9050c826 100644 --- a/tests/plugins/core/models/test_tabular_encoder.py +++ b/tests/plugins/core/models/test_tabular_encoder.py @@ -105,21 +105,17 @@ def check_equal_layouts( layout: list, act_layout: list, disc_act: str, cont_act: str ) -> None: expected_act_layout = [] + for col_info in layout: if col_info.feature_type == "continuous": - expected_act_layout.append(cont_act) - for _ in range(col_info.output_dimensions - 1): - expected_act_layout.append(disc_act) + expected_act_layout.append((cont_act, 1)) + r = col_info.output_dimensions - 1 + if r > 0: + expected_act_layout.append((disc_act, r)) else: - for _ in range(col_info.output_dimensions): - expected_act_layout.append(disc_act) - - expanded_act_layout = [] - for act, num in act_layout: - for _ in range(num): - expanded_act_layout.append(act) + expected_act_layout.append((disc_act, col_info.output_dimensions)) - assert expanded_act_layout == 
expected_act_layout + assert expected_act_layout == act_layout def test_encoder_activation_layout() -> None: diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index 1e8766ac..8fc4664a 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -120,7 +120,7 @@ def test_plugin_generate_constraints(test_plugin: Plugin) -> None: @pytest.mark.parametrize("test_plugin", extend_fixtures()) def test_plugin_hyperparams(test_plugin: Plugin) -> None: - assert len(test_plugin.hyperparameter_space()) == 6 + assert len(test_plugin.hyperparameter_space()) == 4 def test_sample_hyperparams() -> None: From 694cd223c7d6398d2ee3a1ace3998d1f0abce961 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 3 Apr 2023 22:06:30 +0200 Subject: [PATCH 42/95] remove unnecessary code --- src/synthcity/plugins/core/models/factory.py | 2 +- src/synthcity/plugins/core/models/mlp.py | 2 +- src/synthcity/plugins/core/models/tabnet.py | 320 ------------------ .../core/models/tabular_ddpm/__init__.py | 3 +- src/synthcity/utils/dataframe.py | 12 - 5 files changed, 4 insertions(+), 335 deletions(-) diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index e2d69525..586186d5 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -68,7 +68,7 @@ def _factory(type_: Union[str, type], params: dict, registry: dict) -> Any: if isinstance(type_, type): return type_(**params) - type_ = type_.lower().replace("_", "").replace("-", "") + type_ = type_.lower().replace("_", "") if type_ in registry: cls = registry[type_] if isinstance(cls, str): diff --git a/src/synthcity/plugins/core/models/mlp.py b/src/synthcity/plugins/core/models/mlp.py index 5d85c1c8..877dbe9c 100644 --- a/src/synthcity/plugins/core/models/mlp.py +++ b/src/synthcity/plugins/core/models/mlp.py @@ -111,9 +111,9 @@ class MLP(nn.Module): 
@validate_arguments(config=dict(arbitrary_types_allowed=True)) def __init__( self, - *, n_units_in: int, n_units_out: int, + *, task_type: str = "regression", # classification/regression n_layers_hidden: int = 1, n_units_hidden: int = 100, diff --git a/src/synthcity/plugins/core/models/tabnet.py b/src/synthcity/plugins/core/models/tabnet.py index c7e41c52..c5fc702f 100644 --- a/src/synthcity/plugins/core/models/tabnet.py +++ b/src/synthcity/plugins/core/models/tabnet.py @@ -509,323 +509,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.bn(x) out = torch.mul(x[:, : self.output_dim], torch.sigmoid(x[:, self.output_dim :])) return out - - -# class TabNetDecoder(torch.nn.Module): -# def __init__( -# self, -# input_dim, -# n_d=8, -# n_steps=3, -# n_independent=1, -# n_shared=1, -# virtual_batch_size=128, -# momentum=0.02, -# ): -# """ -# Defines main part of the TabNet network without the embedding layers. - -# Parameters -# ---------- -# input_dim : int -# Number of features -# output_dim : int or list of int for multi task classification -# Dimension of network output -# examples : one for regression, 2 for binary classification etc... 
-# n_d : int -# Dimension of the prediction layer (usually between 4 and 64) -# n_steps : int -# Number of successive steps in the network (usually between 3 and 10) -# gamma : float -# Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0) -# n_independent : int -# Number of independent GLU layer in each GLU block (default 1) -# n_shared : int -# Number of independent GLU layer in each GLU block (default 1) -# virtual_batch_size : int -# Batch size for Ghost Batch Normalization -# momentum : float -# Float value between 0 and 1 which will be used for momentum in all batch norm -# """ -# super(TabNetDecoder, self).__init__() -# self.input_dim = input_dim -# self.n_d = n_d -# self.n_steps = n_steps -# self.n_independent = n_independent -# self.n_shared = n_shared -# self.virtual_batch_size = virtual_batch_size - -# self.feat_transformers = torch.nn.ModuleList() - -# if self.n_shared > 0: -# shared_feat_transform = torch.nn.ModuleList() -# for i in range(self.n_shared): -# if i == 0: -# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) -# else: -# shared_feat_transform.append(Linear(n_d, 2 * n_d, bias=False)) - -# else: -# shared_feat_transform = None - -# for step in range(n_steps): -# transformer = FeatTransformer( -# n_d, -# n_d, -# shared_feat_transform, -# n_glu_independent=self.n_independent, -# virtual_batch_size=self.virtual_batch_size, -# momentum=momentum, -# ) -# self.feat_transformers.append(transformer) - -# self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False) -# initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim) - -# def forward(self, steps_output): -# res = 0 -# for step_nb, step_output in enumerate(steps_output): -# x = self.feat_transformers[step_nb](step_output) -# res = torch.add(res, x) -# res = self.reconstruction_layer(res) -# return res - - -# class TabNetPretraining(torch.nn.Module): -# def __init__( -# self, -# input_dim, -# pretraining_ratio=0.2, -# n_d=8, -# n_a=8, -# 
n_steps=3, -# gamma=1.3, -# cat_idxs=[], -# cat_dims=[], -# cat_emb_dim=1, -# n_independent=2, -# n_shared=2, -# epsilon=1e-15, -# virtual_batch_size=128, -# momentum=0.02, -# mask_type="sparsemax", -# n_shared_decoder=1, -# n_indep_decoder=1, -# group_attention_matrix=None, -# ): -# super(TabNetPretraining, self).__init__() - -# self.cat_idxs = cat_idxs or [] -# self.cat_dims = cat_dims or [] -# self.cat_emb_dim = cat_emb_dim - -# self.input_dim = input_dim -# self.n_d = n_d -# self.n_a = n_a -# self.n_steps = n_steps -# self.gamma = gamma -# self.epsilon = epsilon -# self.n_independent = n_independent -# self.n_shared = n_shared -# self.mask_type = mask_type -# self.pretraining_ratio = pretraining_ratio -# self.n_shared_decoder = n_shared_decoder -# self.n_indep_decoder = n_indep_decoder - -# if self.n_steps <= 0: -# raise ValueError("n_steps should be a positive integer.") -# if self.n_independent == 0 and self.n_shared == 0: -# raise ValueError("n_shared and n_independent can't be both zero.") - -# self.virtual_batch_size = virtual_batch_size -# self.embedder = EmbeddingGenerator( -# input_dim, cat_dims, cat_idxs, cat_emb_dim, group_attention_matrix -# ) -# self.post_embed_dim = self.embedder.post_embed_dim - -# self.masker = RandomObfuscator( -# self.pretraining_ratio, group_matrix=self.embedder.embedding_group_matrix -# ) -# self.encoder = TabNetEncoder( -# input_dim=self.post_embed_dim, -# output_dim=self.post_embed_dim, -# n_d=n_d, -# n_a=n_a, -# n_steps=n_steps, -# gamma=gamma, -# n_independent=n_independent, -# n_shared=n_shared, -# epsilon=epsilon, -# virtual_batch_size=virtual_batch_size, -# momentum=momentum, -# mask_type=mask_type, -# group_attention_matrix=self.embedder.embedding_group_matrix, -# ) -# self.decoder = TabNetDecoder( -# self.post_embed_dim, -# n_d=n_d, -# n_steps=n_steps, -# n_independent=self.n_indep_decoder, -# n_shared=self.n_shared_decoder, -# virtual_batch_size=virtual_batch_size, -# momentum=momentum, -# ) - -# def forward(self, 
x): -# """ -# Returns: res, embedded_x, obf_vars -# res : output of reconstruction -# embedded_x : embedded input -# obf_vars : which variable where obfuscated -# """ -# embedded_x = self.embedder(x) -# if self.training: -# masked_x, obfuscated_groups, obfuscated_vars = self.masker(embedded_x) -# # set prior of encoder with obfuscated groups -# prior = 1 - obfuscated_groups -# steps_out, _ = self.encoder(masked_x, prior=prior) -# res = self.decoder(steps_out) -# return res, embedded_x, obfuscated_vars -# else: -# steps_out, _ = self.encoder(embedded_x) -# res = self.decoder(steps_out) -# return res, embedded_x, torch.ones(embedded_x.shape).to(x.device) - -# def forward_masks(self, x): -# embedded_x = self.embedder(x) -# return self.encoder.forward_masks(embedded_x) - - -# class EmbeddingGenerator(torch.nn.Module): -# """ -# Categorical embeddings generator -# """ - -# def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dims, group_matrix): -# """This is an embedding module for an entire set of features - -# Parameters -# ---------- -# input_dim : int -# Number of features coming as input (number of columns) -# cat_dims : list of int -# Number of modalities for each categorial features -# If the list is empty, no embeddings will be done -# cat_idxs : list of int -# Positional index for each categorical features in inputs -# cat_emb_dim : list of int -# Embedding dimension for each categorical features -# If int, the same embedding dimension will be used for all categorical features -# group_matrix : torch matrix -# Original group matrix before embeddings -# """ -# super(EmbeddingGenerator, self).__init__() - -# if cat_dims == [] and cat_idxs == []: -# self.skip_embedding = True -# self.post_embed_dim = input_dim -# self.embedding_group_matrix = group_matrix.to(group_matrix.device) -# return -# else: -# self.skip_embedding = False - -# self.post_embed_dim = int(input_dim + np.sum(cat_emb_dims) - len(cat_emb_dims)) - -# self.embeddings = torch.nn.ModuleList() - 
-# for cat_dim, emb_dim in zip(cat_dims, cat_emb_dims): -# self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) - -# # record continuous indices -# self.continuous_idx = torch.ones(input_dim, dtype=torch.bool) -# self.continuous_idx[cat_idxs] = 0 - -# # update group matrix -# n_groups = group_matrix.shape[0] -# self.embedding_group_matrix = torch.empty( -# (n_groups, self.post_embed_dim), device=group_matrix.device -# ) -# for group_idx in range(n_groups): -# post_emb_idx = 0 -# cat_feat_counter = 0 -# for init_feat_idx in range(input_dim): -# if self.continuous_idx[init_feat_idx] == 1: -# # this means that no embedding is applied to this column -# self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[ -# group_idx, init_feat_idx -# ] # noqa -# post_emb_idx += 1 -# else: -# # this is a categorical feature which creates multiple embeddings -# n_embeddings = cat_emb_dims[cat_feat_counter] -# self.embedding_group_matrix[ -# group_idx, post_emb_idx : post_emb_idx + n_embeddings -# ] = ( -# group_matrix[group_idx, init_feat_idx] / n_embeddings -# ) # noqa -# post_emb_idx += n_embeddings -# cat_feat_counter += 1 - -# def forward(self, x): -# """ -# Apply embeddings to inputs -# Inputs should be (batch_size, input_dim) -# Outputs will be of size (batch_size, self.post_embed_dim) -# """ -# if self.skip_embedding: -# # no embeddings required -# return x - -# cols = [] -# cat_feat_counter = 0 -# for feat_init_idx, is_continuous in enumerate(self.continuous_idx): -# # Enumerate through continuous idx boolean mask to apply embeddings -# if is_continuous: -# cols.append(x[:, feat_init_idx].float().view(-1, 1)) -# else: -# cols.append( -# self.embeddings[cat_feat_counter](x[:, feat_init_idx].long()) -# ) -# cat_feat_counter += 1 -# # concat -# post_embeddings = torch.cat(cols, dim=1) -# return post_embeddings - - -# class RandomObfuscator(torch.nn.Module): -# """ -# Create and applies obfuscation masks. 
-# The obfuscation is done at group level to match attention. -# """ - -# def __init__(self, pretraining_ratio, group_matrix): -# """ -# This create random obfuscation for self suppervised pretraining -# Parameters -# ---------- -# pretraining_ratio : float -# Ratio of feature to randomly discard for reconstruction - -# """ -# super(RandomObfuscator, self).__init__() -# self.pretraining_ratio = pretraining_ratio -# # group matrix is set to boolean here to pass all posssible information -# self.group_matrix = (group_matrix > 0) + 0.0 -# self.num_groups = group_matrix.shape[0] - -# def forward(self, x): -# """ -# Generate random obfuscation mask. - -# Returns -# ------- -# masked input and obfuscated variables. -# """ -# bs = x.shape[0] - -# obfuscated_groups = torch.bernoulli( -# self.pretraining_ratio * torch.ones((bs, self.num_groups), device=x.device) -# ) -# obfuscated_vars = torch.matmul(obfuscated_groups, self.group_matrix) -# masked_input = torch.mul(1 - obfuscated_vars, x) -# return masked_input, obfuscated_groups, obfuscated_vars diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 20a293b4..05edf404 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -88,8 +88,9 @@ def fit( if cat_cols: cat_cols, cat_counts = zip(*cat_cols) + num_cols = X.columns.difference(cat_cols) # reorder the columns so that the categorical ones go to the end - X = X[np.hstack([X.columns[~X.keys().isin(cat_cols)], cat_cols])] + X = X[num_cols.append(cat_cols)] self.feature_names_out = X.columns else: cat_counts = [0] diff --git a/src/synthcity/utils/dataframe.py b/src/synthcity/utils/dataframe.py index 80104e23..a313b91e 100644 --- a/src/synthcity/utils/dataframe.py +++ b/src/synthcity/utils/dataframe.py @@ -1,19 +1,7 @@ -# stdlib -import enum - # third party import pandas as pd -class TaskType(enum.Enum): - 
BINARY = "binary" - MULTICLASS = "multiclass" - REGRESSION = "regression" - - def __str__(self) -> str: - return self.value - - def constant_columns(dataframe: pd.DataFrame) -> list: """ Find constant value columns in a pandas dataframe. From a45978510f20915fc849b54948c99bcae8dd997a Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 6 Apr 2023 19:33:54 +0200 Subject: [PATCH 43/95] fix minor bug and add more nn models in factory --- setup.cfg | 3 +- src/synthcity/plugins/core/models/factory.py | 44 +- .../core/models/tabular_ddpm/__init__.py | 2 +- ...al8_tabular_modelling_with_diffusion.ipynb | 1061 ++++++++++------- 4 files changed, 682 insertions(+), 428 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7be378de..7e20f43c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,8 @@ install_requires = monai tsai; python_version>"3.7" importlib-metadata; python_version<"3.8" - + igraph + pytest-cov [options.packages.find] where = src diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index 586186d5..a23ffc06 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -5,6 +5,14 @@ # third party from pydantic import validate_arguments from torch import nn +from tsai.models.InceptionTime import InceptionTime +from tsai.models.InceptionTimePlus import InceptionTimePlus +from tsai.models.OmniScaleCNN import OmniScaleCNN +from tsai.models.ResCNN import ResCNN +from tsai.models.RNN_FCN import MLSTM_FCN +from tsai.models.TCN import TCN +from tsai.models.XceptionTime import XceptionTime +from tsai.models.XCM import XCM # synthcity relative from .feature_encoder import ( @@ -23,11 +31,22 @@ # should only contain nn modules that can be used as building blocks in larger models MODELS = dict( mlp=".mlp.MLP", + # attention models + transformer=".transformer.TransformerModel", + tabnet=".tabnet.TabNet", + # rnn models rnn=nn.RNN, gru=nn.GRU, lstm=nn.LSTM, 
- transformer=".transformer.TransformerModel", - tabnet=".tabnet.TabNet", + # time series models + inceptiontime=InceptionTime, + omniscalecnn=OmniScaleCNN, + rescnn=ResCNN, + mlstmfcn=MLSTM_FCN, + tcn=TCN, + xceptiontime=XceptionTime, + xcm=XCM, + inceptiontimeplus=InceptionTimePlus, ) ACTIVATIONS = dict( @@ -74,7 +93,7 @@ def _factory(type_: Union[str, type], params: dict, registry: dict) -> Any: if isinstance(cls, str): cls = registry[type_] = _dynamic_import(cls) return cls(**params) - raise ValueError + raise ValueError(f"Unknown type: {type_}") def _dynamic_import(path: str) -> type: @@ -83,9 +102,9 @@ def _dynamic_import(path: str) -> type: package = __name__.rsplit(".", 1)[0] else: package = None - mod_path, cls = path.rsplit(".", 1) + mod_path, name = path.rsplit(".", 1) module = import_module(mod_path, package) - return getattr(module, cls) + return getattr(module, name) @validate_arguments(config=dict(arbitrary_types_allowed=True)) @@ -99,19 +118,13 @@ def get_model(block: Union[str, type], params: dict) -> Any: - transformer - tabnet """ - try: - return _factory(block, params, MODELS) - except ValueError: - raise ValueError(f"Unknown nn model: {block}") + return _factory(block, params, MODELS) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def get_nonlin(nonlin: Union[str, nn.Module], params: dict = {}) -> Any: """Get a nonlinearity layer from a name or a class.""" - try: - return _factory(nonlin, params, ACTIVATIONS) - except ValueError: - raise ValueError(f"Unknown nonlinearity: {nonlin}") + return _factory(nonlin, params, ACTIVATIONS) @validate_arguments(config=dict(arbitrary_types_allowed=True)) @@ -131,7 +144,4 @@ def get_feature_encoder(encoder: Union[str, type], params: dict = {}) -> Any: """ if isinstance(encoder, type): # custom encoder encoder = FeatureEncoder.wraps(encoder) - try: - return _factory(encoder, params, FEATURE_ENCODERS) - except ValueError: - raise ValueError(f"Unknown feature encoder: {encoder}") + return 
_factory(encoder, params, FEATURE_ENCODERS) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index 05edf404..713f815e 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -90,7 +90,7 @@ def fit( cat_cols, cat_counts = zip(*cat_cols) num_cols = X.columns.difference(cat_cols) # reorder the columns so that the categorical ones go to the end - X = X[num_cols.append(cat_cols)] + X = X[list(num_cols) + list(cat_cols)] self.feature_names_out = X.columns else: cat_counts = [0] diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb index d73d0f60..d07618a1 100644 --- a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -217,26 +217,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-31T01:04:28.062034+0200][12004][INFO] Encoding sepal length (cm) 8461685668942494555\n", - "[2023-03-31T01:04:28.068034+0200][12004][INFO] Encoding sepal width (cm) 7372477013158199918\n", - "[2023-03-31T01:04:28.074037+0200][12004][INFO] Encoding petal length (cm) 8795408021141068254\n", - "[2023-03-31T01:04:28.081036+0200][12004][INFO] Encoding petal width (cm) 1839870727438321343\n", - "[2023-03-31T01:04:29.905425+0200][12004][INFO] Step 100: MLoss: 0.0 GLoss: 0.3103 Sum: 0.3103\n", - "[2023-03-31T01:04:31.486761+0200][12004][INFO] Step 200: MLoss: 0.0 GLoss: 0.3111 Sum: 0.3111\n", - "[2023-03-31T01:04:33.076905+0200][12004][INFO] Step 300: MLoss: 0.0 GLoss: 0.317 Sum: 0.317\n", - "[2023-03-31T01:04:34.611746+0200][12004][INFO] Step 400: MLoss: 0.0 GLoss: 0.3009 Sum: 0.3009\n", - "[2023-03-31T01:04:36.176039+0200][12004][INFO] Step 500: MLoss: 0.0 GLoss: 0.3154 Sum: 0.3154\n", - "[2023-03-31T01:04:37.956754+0200][12004][INFO] Step 600: MLoss: 0.0 GLoss: 
0.3055 Sum: 0.3055\n", - "[2023-03-31T01:04:39.561269+0200][12004][INFO] Step 700: MLoss: 0.0 GLoss: 0.2917 Sum: 0.2917\n", - "[2023-03-31T01:04:41.195544+0200][12004][INFO] Step 800: MLoss: 0.0 GLoss: 0.2817 Sum: 0.2817\n", - "[2023-03-31T01:04:42.967236+0200][12004][INFO] Step 900: MLoss: 0.0 GLoss: 0.266 Sum: 0.266\n", - "[2023-03-31T01:04:44.913448+0200][12004][INFO] Step 1000: MLoss: 0.0 GLoss: 0.2793 Sum: 0.2793\n" + "[2023-04-06T19:07:53.035827+0200][45392][INFO] Encoding sepal length (cm) 8461685668942494555\n", + "[2023-04-06T19:07:53.045457+0200][45392][INFO] Encoding sepal width (cm) 7372477013158199918\n", + "[2023-04-06T19:07:53.054429+0200][45392][INFO] Encoding petal length (cm) 8795408021141068254\n", + "[2023-04-06T19:07:53.066673+0200][45392][INFO] Encoding petal width (cm) 1839870727438321343\n", + "[2023-04-06T19:07:55.483186+0200][45392][INFO] Step 100: MLoss: 0.0 GLoss: 0.3032 Sum: 0.3032\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -253,11 +244,13 @@ " weight_decay = 1e-4,\n", " batch_size = 1000,\n", " model_type = \"mlp\", # or \"resnet\"\n", + " model_params = dict(\n", + " n_layers_hidden = 3,\n", + " n_units_hidden = 256,\n", + " dropout = 0.0,\n", + " ),\n", " num_timesteps = 500, # timesteps in diffusion\n", - " n_layers_hidden = 3,\n", - " dim_hidden = 256,\n", " dim_embed = 128,\n", - " dropout = 0.0,\n", " # performance logging\n", " log_interval = 10,\n", " print_interval = 100,\n", @@ -278,7 +271,7 @@ "text/plain": [ "TabDDPM(\n", " (diffusion): GaussianMultinomialDiffusion(\n", - " (denoise_fn): MLPDiffusion(\n", + " (denoise_fn): DiffusionModel(\n", " (emb_nonlin): SiLU()\n", " (proj): Linear(in_features=4, out_features=128, bias=True)\n", " (time_emb): TimeStepEmbedding(\n", @@ -315,7 +308,7 @@ " )\n", " )\n", " )\n", - " (ema_model): MLPDiffusion(\n", + " (ema_model): DiffusionModel(\n", " (emb_nonlin): SiLU()\n", " (proj): Linear(in_features=4, out_features=128, bias=True)\n", " (time_emb): 
TimeStepEmbedding(\n", @@ -372,7 +365,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -381,14 +374,12 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -445,82 +436,82 @@ " \n", " \n", " 0\n", - " 6.491386\n", - " 2.937301\n", - " 4.396537\n", - " 1.363964\n", + " 4.300000\n", + " 4.400000\n", + " 1.000000\n", + " 0.100000\n", " 1\n", " \n", " \n", " 1\n", - " 6.272807\n", - " 2.878930\n", - " 5.028617\n", - " 1.973149\n", + " 7.900000\n", + " 4.400000\n", + " 6.900000\n", + " 2.500000\n", " 2\n", " \n", " \n", " 2\n", - " 4.912787\n", - " 2.239502\n", - " 2.384605\n", - " 0.845205\n", + " 5.740312\n", + " 2.060491\n", + " 2.659118\n", + " 0.982462\n", " 1\n", " \n", " \n", " 3\n", - " 5.115768\n", - " 2.636920\n", - " 3.933653\n", - " 1.100583\n", + " 4.300000\n", + " 2.000000\n", + " 1.000000\n", + " 0.100002\n", " 1\n", " \n", " \n", " 4\n", - " 5.946947\n", - " 2.976103\n", - " 4.557983\n", - " 1.417799\n", + " 4.300000\n", + " 2.000000\n", + " 1.000000\n", + " 0.100000\n", " 1\n", " \n", " \n", " 5\n", - " 5.528565\n", - " 2.197114\n", - " 4.133016\n", - " 1.296019\n", + " 4.505079\n", + " 2.025100\n", + " 3.619546\n", + " 0.208050\n", " 1\n", " \n", " \n", " 6\n", - " 5.275113\n", - " 2.565652\n", - " 3.698843\n", - " 1.068934\n", + " 6.108867\n", + " 2.511960\n", + " 4.668570\n", + " 1.392304\n", " 1\n", " \n", " \n", " 7\n", - " 7.900000\n", - " 4.400000\n", - " 6.899995\n", - " 2.500000\n", - " 2\n", + " 5.536599\n", + " 2.549499\n", + " 4.443977\n", + " 1.345480\n", + " 1\n", " \n", " \n", " 8\n", - " 6.899334\n", - " 2.847685\n", - " 6.243627\n", - " 1.561012\n", + " 7.900000\n", + " 4.400000\n", + " 6.900000\n", + " 2.500000\n", " 2\n", " \n", " \n", " 9\n", - " 5.267148\n", - " 2.780006\n", - " 3.565531\n", - " 1.128439\n", + " 4.300000\n", + " 2.000000\n", + " 1.000000\n", + " 0.100000\n", " 1\n", " \n", " \n", @@ -529,16 +520,16 @@ ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", - "0 6.491386 2.937301 4.396537 1.363964 \n", - "1 6.272807 
2.878930 5.028617 1.973149 \n", - "2 4.912787 2.239502 2.384605 0.845205 \n", - "3 5.115768 2.636920 3.933653 1.100583 \n", - "4 5.946947 2.976103 4.557983 1.417799 \n", - "5 5.528565 2.197114 4.133016 1.296019 \n", - "6 5.275113 2.565652 3.698843 1.068934 \n", - "7 7.900000 4.400000 6.899995 2.500000 \n", - "8 6.899334 2.847685 6.243627 1.561012 \n", - "9 5.267148 2.780006 3.565531 1.128439 \n", + "0 4.300000 4.400000 1.000000 0.100000 \n", + "1 7.900000 4.400000 6.900000 2.500000 \n", + "2 5.740312 2.060491 2.659118 0.982462 \n", + "3 4.300000 2.000000 1.000000 0.100002 \n", + "4 4.300000 2.000000 1.000000 0.100000 \n", + "5 4.505079 2.025100 3.619546 0.208050 \n", + "6 6.108867 2.511960 4.668570 1.392304 \n", + "7 5.536599 2.549499 4.443977 1.345480 \n", + "8 7.900000 4.400000 6.900000 2.500000 \n", + "9 4.300000 2.000000 1.000000 0.100000 \n", "\n", " target \n", "0 1 \n", @@ -548,7 +539,7 @@ "4 1 \n", "5 1 \n", "6 1 \n", - "7 2 \n", + "7 1 \n", "8 2 \n", "9 1 " ] @@ -608,75 +599,75 @@ " \n", " \n", " 0\n", - " 5.230361\n", - " 3.371515\n", - " 1.408195\n", - " 0.201252\n", + " 7.900000\n", + " 4.400000\n", + " 1.000000\n", + " 2.459848\n", " 0\n", " \n", " \n", " 1\n", - " 4.705658\n", - " 3.064075\n", - " 1.388975\n", - " 0.386298\n", + " 4.300000\n", + " 4.400000\n", + " 1.000000\n", + " 0.100000\n", " 0\n", " \n", " \n", " 2\n", - " 4.711709\n", - " 3.056369\n", - " 1.451635\n", - " 0.195365\n", + " 7.900000\n", + " 4.400000\n", + " 1.000000\n", + " 0.100000\n", " 0\n", " \n", " \n", " 3\n", - " 6.981074\n", - " 3.274333\n", - " 4.803886\n", - " 1.623058\n", + " 5.499909\n", + " 2.023536\n", + " 3.788866\n", + " 1.262810\n", " 1\n", " \n", " \n", " 4\n", - " 5.999308\n", - " 2.927207\n", - " 4.040594\n", - " 1.389657\n", + " 4.300000\n", + " 2.000000\n", + " 1.000000\n", + " 0.100000\n", " 1\n", " \n", " \n", " 5\n", - " 5.698102\n", - " 2.521559\n", - " 3.288451\n", - " 0.966808\n", + " 4.300000\n", + " 2.000000\n", + " 1.000000\n", + " 0.197491\n", " 
1\n", " \n", " \n", " 6\n", - " 6.776549\n", - " 3.012238\n", - " 6.285867\n", - " 2.134174\n", + " 7.900000\n", + " 4.400000\n", + " 6.900000\n", + " 2.500000\n", " 2\n", " \n", " \n", " 7\n", - " 7.900000\n", + " 4.300000\n", " 4.400000\n", - " 6.896603\n", - " 2.500000\n", - " 2\n", + " 1.000000\n", + " 0.100000\n", + " 0\n", " \n", " \n", " 8\n", - " 7.900000\n", - " 4.400000\n", - " 6.898989\n", - " 2.500000\n", - " 2\n", + " 4.300000\n", + " 2.000000\n", + " 1.000001\n", + " 0.142259\n", + " 1\n", " \n", " \n", "\n", @@ -684,15 +675,15 @@ ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", - "0 5.230361 3.371515 1.408195 0.201252 \n", - "1 4.705658 3.064075 1.388975 0.386298 \n", - "2 4.711709 3.056369 1.451635 0.195365 \n", - "3 6.981074 3.274333 4.803886 1.623058 \n", - "4 5.999308 2.927207 4.040594 1.389657 \n", - "5 5.698102 2.521559 3.288451 0.966808 \n", - "6 6.776549 3.012238 6.285867 2.134174 \n", - "7 7.900000 4.400000 6.896603 2.500000 \n", - "8 7.900000 4.400000 6.898989 2.500000 \n", + "0 7.900000 4.400000 1.000000 2.459848 \n", + "1 4.300000 4.400000 1.000000 0.100000 \n", + "2 7.900000 4.400000 1.000000 0.100000 \n", + "3 5.499909 2.023536 3.788866 1.262810 \n", + "4 4.300000 2.000000 1.000000 0.100000 \n", + "5 4.300000 2.000000 1.000000 0.197491 \n", + "6 7.900000 4.400000 6.900000 2.500000 \n", + "7 4.300000 4.400000 1.000000 0.100000 \n", + "8 4.300000 2.000000 1.000001 0.142259 \n", "\n", " target \n", "0 0 \n", @@ -702,8 +693,8 @@ "4 1 \n", "5 1 \n", "6 2 \n", - "7 2 \n", - "8 2 " + "7 0 \n", + "8 1 " ] }, "execution_count": 8, @@ -941,7 +932,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "14bca1cd", "metadata": {}, "outputs": [ @@ -949,47 +940,31 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-31T01:04:50.377220+0200][12004][INFO] Encoding fixed acidity 8821222230854998919\n", - "[2023-03-31T01:04:50.427480+0200][12004][INFO] Encoding 
volatile acidity 3689048099044143611\n", - "[2023-03-31T01:04:50.442050+0200][12004][INFO] Encoding citric acid 735380040632581265\n", - "[2023-03-31T01:04:50.457233+0200][12004][INFO] Encoding residual sugar 2442409671939919968\n", - "[2023-03-31T01:04:50.473234+0200][12004][INFO] Encoding chlorides 7195838597182208600\n", - "[2023-03-31T01:04:50.488234+0200][12004][INFO] Encoding free sulfur dioxide 3309873879720413309\n", - "[2023-03-31T01:04:50.501098+0200][12004][INFO] Encoding total sulfur dioxide 8059822526963442530\n", - "[2023-03-31T01:04:50.512236+0200][12004][INFO] Encoding density 3625281346475756911\n", - "[2023-03-31T01:04:50.523222+0200][12004][INFO] Encoding pH 4552002723230490789\n", - "[2023-03-31T01:04:50.532220+0200][12004][INFO] Encoding sulphates 4957484118723629481\n", - "[2023-03-31T01:04:50.540983+0200][12004][INFO] Encoding alcohol 3711001505059098944\n", - "[2023-03-31T01:04:50.547987+0200][12004][INFO] Encoding quality 3457201635469827215\n", - "[2023-03-31T01:04:58.399971+0200][12004][INFO] Step 100: MLoss: 1.3342 GLoss: 0.9783 Sum: 2.3125\n", - "[2023-03-31T01:05:04.973385+0200][12004][INFO] Step 200: MLoss: 1.2858 GLoss: 0.9031 Sum: 2.1889000000000003\n", - "[2023-03-31T01:05:11.741000+0200][12004][INFO] Step 300: MLoss: 1.186 GLoss: 0.7758 Sum: 1.9618\n", - "[2023-03-31T01:05:18.619270+0200][12004][INFO] Step 400: MLoss: 1.1481 GLoss: 0.6615 Sum: 1.8095999999999999\n", - "[2023-03-31T01:05:24.930108+0200][12004][INFO] Step 500: MLoss: 1.1661 GLoss: 0.6094 Sum: 1.7755\n", - "[2023-03-31T01:05:31.651906+0200][12004][INFO] Step 600: MLoss: 1.1902 GLoss: 0.5381 Sum: 1.7283\n", - "[2023-03-31T01:05:38.246164+0200][12004][INFO] Step 700: MLoss: 1.1305 GLoss: 0.5087 Sum: 1.6392000000000002\n", - "[2023-03-31T01:05:44.776216+0200][12004][INFO] Step 800: MLoss: 1.1131 GLoss: 0.4832 Sum: 1.5963\n", - "[2023-03-31T01:05:51.917105+0200][12004][INFO] Step 900: MLoss: 1.1014 GLoss: 0.4786 Sum: 1.58\n", - 
"[2023-03-31T01:05:59.098745+0200][12004][INFO] Step 1000: MLoss: 1.1479 GLoss: 0.4707 Sum: 1.6185999999999998\n", - "[2023-03-31T01:06:05.690366+0200][12004][INFO] Step 1100: MLoss: 1.1712 GLoss: 0.4693 Sum: 1.6405\n", - "[2023-03-31T01:06:12.549553+0200][12004][INFO] Step 1200: MLoss: 1.1199 GLoss: 0.4611 Sum: 1.581\n", - "[2023-03-31T01:06:19.575478+0200][12004][INFO] Step 1300: MLoss: 1.1525 GLoss: 0.4614 Sum: 1.6139000000000001\n", - "[2023-03-31T01:06:26.641319+0200][12004][INFO] Step 1400: MLoss: 1.1164 GLoss: 0.4671 Sum: 1.5835000000000001\n", - "[2023-03-31T01:06:33.249503+0200][12004][INFO] Step 1500: MLoss: 1.1356 GLoss: 0.4577 Sum: 1.5933\n", - "[2023-03-31T01:06:40.025759+0200][12004][INFO] Step 1600: MLoss: 1.1367 GLoss: 0.4541 Sum: 1.5908\n", - "[2023-03-31T01:06:46.754777+0200][12004][INFO] Step 1700: MLoss: 1.0896 GLoss: 0.4524 Sum: 1.5419999999999998\n", - "[2023-03-31T01:06:54.036939+0200][12004][INFO] Step 1800: MLoss: 1.075 GLoss: 0.4471 Sum: 1.5221\n", - "[2023-03-31T01:07:00.554405+0200][12004][INFO] Step 1900: MLoss: 1.1154 GLoss: 0.4495 Sum: 1.5649\n", - "[2023-03-31T01:07:07.289610+0200][12004][INFO] Step 2000: MLoss: 1.266 GLoss: 0.454 Sum: 1.72\n" + "[2023-04-06T19:09:16.010623+0200][45392][INFO] Encoding fixed acidity 8821222230854998919\n", + "[2023-04-06T19:09:16.022381+0200][45392][INFO] Encoding volatile acidity 3689048099044143611\n", + "[2023-04-06T19:09:16.035202+0200][45392][INFO] Encoding citric acid 735380040632581265\n", + "[2023-04-06T19:09:16.046041+0200][45392][INFO] Encoding residual sugar 2442409671939919968\n", + "[2023-04-06T19:09:16.057037+0200][45392][INFO] Encoding chlorides 7195838597182208600\n", + "[2023-04-06T19:09:16.069198+0200][45392][INFO] Encoding free sulfur dioxide 3309873879720413309\n", + "[2023-04-06T19:09:16.079198+0200][45392][INFO] Encoding total sulfur dioxide 8059822526963442530\n", + "[2023-04-06T19:09:16.089218+0200][45392][INFO] Encoding density 3625281346475756911\n", + 
"[2023-04-06T19:09:16.100269+0200][45392][INFO] Encoding pH 4552002723230490789\n", + "[2023-04-06T19:09:16.108269+0200][45392][INFO] Encoding sulphates 4957484118723629481\n", + "[2023-04-06T19:09:16.118284+0200][45392][INFO] Encoding alcohol 3711001505059098944\n", + "[2023-04-06T19:09:16.128449+0200][45392][INFO] Encoding quality 3457201635469827215\n", + "[2023-04-06T19:09:23.491031+0200][45392][INFO] Step 100: MLoss: 1.3299 GLoss: 0.9771 Sum: 2.307\n", + "[2023-04-06T19:09:31.041829+0200][45392][INFO] Step 200: MLoss: 1.2726 GLoss: 0.9268 Sum: 2.1994\n", + "[2023-04-06T19:09:39.672711+0200][45392][INFO] Step 300: MLoss: 1.2003 GLoss: 0.8693 Sum: 2.0696\n", + "[2023-04-06T19:09:48.042293+0200][45392][INFO] Step 400: MLoss: 1.1578 GLoss: 0.8343 Sum: 1.9921\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -1002,8 +977,11 @@ " lr = 5e-4,\n", " weight_decay = 1e-4,\n", " batch_size = 1250,\n", - " n_layers_hidden = 3,\n", - " dim_hidden = 256,\n", + " model_params = dict(\n", + " n_layers_hidden = 3,\n", + " n_units_hidden = 256,\n", + " dropout = 0.0,\n", + " ),\n", " num_timesteps = 100, # timesteps in diffusion\n", ")\n", "plugin = Plugins().get(\"ddpm\", **plugin_params)\n", @@ -1012,30 +990,28 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "83064f94", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -1045,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "af9d6df1", "metadata": {}, "outputs": [ @@ -1087,123 +1063,123 @@ " \n", " \n", " 0\n", - " 7.400000\n", - " 0.080079\n", - " 0.000000\n", - " 65.800000\n", - " 0.025529\n", - " 67.119106\n", - " 440.000000\n", - " 1.038980\n", + " 14.2\n", + " 0.08\n", + " 1.66\n", + " 65.8\n", + " 0.346\n", + " 289.0\n", + " 9.0\n", + " 1.03898\n", " 2.720000\n", - " 1.080000\n", - " 8.0\n", - " 7\n", + " 0.22\n", + " 14.2\n", + " 6\n", " \n", " \n", " 1\n", - " 5.088797\n", - " 0.112499\n", - " 0.370000\n", - " 0.763677\n", - " 0.009000\n", - " 288.824821\n", - " 98.000000\n", - " 0.987110\n", - " 3.240000\n", - " 0.220000\n", " 14.2\n", - " 8\n", + " 1.10\n", + " 0.00\n", + " 0.6\n", + " 0.346\n", + " 289.0\n", + " 440.0\n", + " 0.98711\n", + " 2.720000\n", + " 1.08\n", + " 14.2\n", + " 7\n", " \n", " \n", " 2\n", - " 3.800000\n", - " 1.100000\n", - " 0.000000\n", - " 0.600000\n", - " 0.009000\n", - " 2.000000\n", - " 9.000000\n", - " 1.038980\n", + " 3.8\n", + " 0.08\n", + " 1.66\n", + " 65.8\n", + " 0.346\n", + " 289.0\n", + " 9.0\n", + " 0.98711\n", " 3.820000\n", - " 0.220000\n", + " 0.22\n", " 8.0\n", - " 4\n", + " 5\n", " \n", " \n", " 3\n", - " 3.800000\n", - " 0.080000\n", - " 1.659603\n", - " 0.600000\n", - " 0.034734\n", - " 2.000000\n", - " 9.000000\n", - " 0.987110\n", - " 3.775879\n", - " 1.080000\n", - " 9.5\n", + " 3.8\n", + " 0.08\n", + " 0.00\n", + " 0.6\n", + " 0.346\n", + " 289.0\n", + " 440.0\n", + " 0.98711\n", + " 2.720000\n", + " 0.22\n", + " 14.2\n", " 7\n", " \n", " \n", " 4\n", - " 5.700000\n", - " 0.330000\n", - " 0.213874\n", - " 10.937306\n", - " 0.050000\n", - " 39.064968\n", - " 147.790987\n", - " 0.997247\n", - " 3.330984\n", - " 0.380000\n", - " 8.7\n", - " 6\n", + " 14.2\n", + " 1.10\n", + " 0.00\n", + " 0.6\n", + " 0.346\n", + " 
289.0\n", + " 9.0\n", + " 0.98711\n", + " 3.771223\n", + " 1.08\n", + " 14.2\n", + " 7\n", " \n", " \n", " 5\n", - " 14.200000\n", - " 0.080000\n", - " 0.000000\n", - " 0.600000\n", - " 0.009000\n", - " 2.000000\n", - " 9.000000\n", - " 0.987110\n", - " 2.916428\n", - " 0.220055\n", - " 9.5\n", - " 5\n", + " 14.2\n", + " 1.10\n", + " 1.66\n", + " 65.8\n", + " 0.009\n", + " 2.0\n", + " 440.0\n", + " 0.98711\n", + " 2.720000\n", + " 1.08\n", + " 14.2\n", + " 7\n", " \n", " \n", " 6\n", - " 14.200000\n", - " 0.087887\n", - " 1.660000\n", - " 0.600000\n", - " 0.108297\n", - " 49.000000\n", - " 65.466909\n", - " 0.987117\n", - " 2.720090\n", - " 0.220006\n", + " 3.8\n", + " 0.08\n", + " 0.00\n", + " 65.8\n", + " 0.346\n", + " 2.0\n", " 9.0\n", - " 6\n", + " 1.03898\n", + " 2.720000\n", + " 1.08\n", + " 8.0\n", + " 7\n", " \n", " \n", " 7\n", - " 8.870765\n", - " 1.099817\n", - " 1.657142\n", - " 12.921528\n", - " 0.025276\n", - " 288.846488\n", - " 438.337342\n", - " 0.996196\n", - " 2.724725\n", - " 0.220049\n", - " 10.2\n", - " 5\n", + " 14.2\n", + " 0.08\n", + " 1.66\n", + " 0.6\n", + " 0.346\n", + " 289.0\n", + " 9.0\n", + " 0.98711\n", + " 3.820000\n", + " 0.22\n", + " 14.2\n", + " 7\n", " \n", " \n", "\n", @@ -1211,37 +1187,37 @@ ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", - "0 7.400000 0.080079 0.000000 65.800000 0.025529 \n", - "1 5.088797 0.112499 0.370000 0.763677 0.009000 \n", - "2 3.800000 1.100000 0.000000 0.600000 0.009000 \n", - "3 3.800000 0.080000 1.659603 0.600000 0.034734 \n", - "4 5.700000 0.330000 0.213874 10.937306 0.050000 \n", - "5 14.200000 0.080000 0.000000 0.600000 0.009000 \n", - "6 14.200000 0.087887 1.660000 0.600000 0.108297 \n", - "7 8.870765 1.099817 1.657142 12.921528 0.025276 \n", + "0 14.2 0.08 1.66 65.8 0.346 \n", + "1 14.2 1.10 0.00 0.6 0.346 \n", + "2 3.8 0.08 1.66 65.8 0.346 \n", + "3 3.8 0.08 0.00 0.6 0.346 \n", + "4 14.2 1.10 0.00 0.6 0.346 \n", + "5 14.2 1.10 1.66 65.8 0.009 
\n", + "6 3.8 0.08 0.00 65.8 0.346 \n", + "7 14.2 0.08 1.66 0.6 0.346 \n", "\n", - " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", - "0 67.119106 440.000000 1.038980 2.720000 1.080000 \n", - "1 288.824821 98.000000 0.987110 3.240000 0.220000 \n", - "2 2.000000 9.000000 1.038980 3.820000 0.220000 \n", - "3 2.000000 9.000000 0.987110 3.775879 1.080000 \n", - "4 39.064968 147.790987 0.997247 3.330984 0.380000 \n", - "5 2.000000 9.000000 0.987110 2.916428 0.220055 \n", - "6 49.000000 65.466909 0.987117 2.720090 0.220006 \n", - "7 288.846488 438.337342 0.996196 2.724725 0.220049 \n", + " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", + "0 289.0 9.0 1.03898 2.720000 0.22 \n", + "1 289.0 440.0 0.98711 2.720000 1.08 \n", + "2 289.0 9.0 0.98711 3.820000 0.22 \n", + "3 289.0 440.0 0.98711 2.720000 0.22 \n", + "4 289.0 9.0 0.98711 3.771223 1.08 \n", + "5 2.0 440.0 0.98711 2.720000 1.08 \n", + "6 2.0 9.0 1.03898 2.720000 1.08 \n", + "7 289.0 9.0 0.98711 3.820000 0.22 \n", "\n", " alcohol quality \n", - "0 8.0 7 \n", - "1 14.2 8 \n", - "2 8.0 4 \n", - "3 9.5 7 \n", - "4 8.7 6 \n", - "5 9.5 5 \n", - "6 9.0 6 \n", - "7 10.2 5 " + "0 14.2 6 \n", + "1 14.2 7 \n", + "2 8.0 5 \n", + "3 14.2 7 \n", + "4 14.2 7 \n", + "5 14.2 7 \n", + "6 8.0 7 \n", + "7 14.2 7 " ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1261,9 +1237,18 @@ "A conditional variable `cond` can be provided to the `fit` method. It can be either a column name in the dataset or a custom array. The model will then learn the conditional distribution of the dataset given `cond`. In this case, an array must be provided as the `cond` argument of the `generate` method." 
] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "191d1b6f", + "metadata": {}, + "source": [ + "Use a column name as the `cond` argument in the `fit` method:" + ] + }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "56a1fc7e", "metadata": {}, "outputs": [ @@ -1271,47 +1256,31 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-03-31T01:07:08.859587+0200][12004][INFO] Encoding fixed acidity 8821222230854998919\n", - "[2023-03-31T01:07:08.873767+0200][12004][INFO] Encoding volatile acidity 3689048099044143611\n", - "[2023-03-31T01:07:08.885765+0200][12004][INFO] Encoding citric acid 735380040632581265\n", - "[2023-03-31T01:07:08.896357+0200][12004][INFO] Encoding residual sugar 2442409671939919968\n", - "[2023-03-31T01:07:08.904579+0200][12004][INFO] Encoding chlorides 7195838597182208600\n", - "[2023-03-31T01:07:08.914577+0200][12004][INFO] Encoding free sulfur dioxide 3309873879720413309\n", - "[2023-03-31T01:07:08.922581+0200][12004][INFO] Encoding total sulfur dioxide 8059822526963442530\n", - "[2023-03-31T01:07:08.930580+0200][12004][INFO] Encoding density 3625281346475756911\n", - "[2023-03-31T01:07:08.939216+0200][12004][INFO] Encoding pH 4552002723230490789\n", - "[2023-03-31T01:07:08.947216+0200][12004][INFO] Encoding sulphates 4957484118723629481\n", - "[2023-03-31T01:07:08.956217+0200][12004][INFO] Encoding alcohol 3711001505059098944\n", - "[2023-03-31T01:07:08.964215+0200][12004][INFO] Encoding quality 3457201635469827215\n", - "[2023-03-31T01:07:17.078379+0200][12004][INFO] Step 100: MLoss: 0.9932 GLoss: 0.9775 Sum: 1.9707\n", - "[2023-03-31T01:07:24.055012+0200][12004][INFO] Step 200: MLoss: 0.2957 GLoss: 0.9254 Sum: 1.2211\n", - "[2023-03-31T01:07:32.461826+0200][12004][INFO] Step 300: MLoss: 0.0748 GLoss: 0.8407 Sum: 0.9155\n", - "[2023-03-31T01:07:39.522162+0200][12004][INFO] Step 400: MLoss: 0.0289 GLoss: 0.7444 Sum: 0.7733\n", - "[2023-03-31T01:07:47.110402+0200][12004][INFO] Step 500: 
MLoss: 0.0292 GLoss: 0.6655 Sum: 0.6947\n", - "[2023-03-31T01:07:54.622795+0200][12004][INFO] Step 600: MLoss: 0.0229 GLoss: 0.5844 Sum: 0.6073000000000001\n", - "[2023-03-31T01:08:01.951234+0200][12004][INFO] Step 700: MLoss: 0.0218 GLoss: 0.5572 Sum: 0.5790000000000001\n", - "[2023-03-31T01:08:09.957993+0200][12004][INFO] Step 800: MLoss: 0.0091 GLoss: 0.531 Sum: 0.5401\n", - "[2023-03-31T01:08:18.931373+0200][12004][INFO] Step 900: MLoss: 0.0114 GLoss: 0.5286 Sum: 0.5399999999999999\n", - "[2023-03-31T01:08:26.898063+0200][12004][INFO] Step 1000: MLoss: 0.0099 GLoss: 0.5259 Sum: 0.5358\n", - "[2023-03-31T01:08:34.593930+0200][12004][INFO] Step 1100: MLoss: 0.0106 GLoss: 0.5196 Sum: 0.5302\n", - "[2023-03-31T01:08:41.818482+0200][12004][INFO] Step 1200: MLoss: 0.0105 GLoss: 0.5072 Sum: 0.5176999999999999\n", - "[2023-03-31T01:08:49.426481+0200][12004][INFO] Step 1300: MLoss: 0.0086 GLoss: 0.5112 Sum: 0.5198\n", - "[2023-03-31T01:08:56.953344+0200][12004][INFO] Step 1400: MLoss: 0.0106 GLoss: 0.516 Sum: 0.5266000000000001\n", - "[2023-03-31T01:09:04.509760+0200][12004][INFO] Step 1500: MLoss: 0.0075 GLoss: 0.5062 Sum: 0.5136999999999999\n", - "[2023-03-31T01:09:11.742216+0200][12004][INFO] Step 1600: MLoss: 0.0098 GLoss: 0.5012 Sum: 0.511\n", - "[2023-03-31T01:09:19.870988+0200][12004][INFO] Step 1700: MLoss: 0.0088 GLoss: 0.499 Sum: 0.5078\n", - "[2023-03-31T01:09:27.578035+0200][12004][INFO] Step 1800: MLoss: 0.0163 GLoss: 0.4956 Sum: 0.5119\n", - "[2023-03-31T01:09:34.406045+0200][12004][INFO] Step 1900: MLoss: 0.0046 GLoss: 0.4955 Sum: 0.5001\n", - "[2023-03-31T01:09:41.645411+0200][12004][INFO] Step 2000: MLoss: 0.017 GLoss: 0.5008 Sum: 0.5178\n" + "[2023-04-06T19:10:28.307332+0200][45392][INFO] Encoding fixed acidity 8821222230854998919\n", + "[2023-04-06T19:10:28.316302+0200][45392][INFO] Encoding volatile acidity 3689048099044143611\n", + "[2023-04-06T19:10:28.328835+0200][45392][INFO] Encoding citric acid 735380040632581265\n", + 
"[2023-04-06T19:10:28.337818+0200][45392][INFO] Encoding residual sugar 2442409671939919968\n", + "[2023-04-06T19:10:28.346502+0200][45392][INFO] Encoding chlorides 7195838597182208600\n", + "[2023-04-06T19:10:28.355523+0200][45392][INFO] Encoding free sulfur dioxide 3309873879720413309\n", + "[2023-04-06T19:10:28.367907+0200][45392][INFO] Encoding total sulfur dioxide 8059822526963442530\n", + "[2023-04-06T19:10:28.379128+0200][45392][INFO] Encoding density 3625281346475756911\n", + "[2023-04-06T19:10:28.388190+0200][45392][INFO] Encoding pH 4552002723230490789\n", + "[2023-04-06T19:10:28.396086+0200][45392][INFO] Encoding sulphates 4957484118723629481\n", + "[2023-04-06T19:10:28.404089+0200][45392][INFO] Encoding alcohol 3711001505059098944\n", + "[2023-04-06T19:10:28.412665+0200][45392][INFO] Encoding quality 3457201635469827215\n", + "[2023-04-06T19:10:35.956508+0200][45392][INFO] Step 100: MLoss: 1.0404 GLoss: 0.9809 Sum: 2.0213\n", + "[2023-04-06T19:10:42.296829+0200][45392][INFO] Step 200: MLoss: 0.4041 GLoss: 0.9524 Sum: 1.3565\n", + "[2023-04-06T19:10:50.537642+0200][45392][INFO] Step 300: MLoss: 0.1456 GLoss: 0.9186 Sum: 1.0642\n", + "[2023-04-06T19:10:59.461444+0200][45392][INFO] Step 400: MLoss: 0.0757 GLoss: 0.898 Sum: 0.9737\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -1322,30 +1291,28 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "3fcb9493", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -1355,7 +1322,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "2ea981cd", "metadata": {}, "outputs": [ @@ -1365,7 +1332,7 @@ "array([3, 4, 5, 6, 7, 8, 9])" ] }, - "execution_count": 15, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -1377,7 +1344,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "id": "bbd33233", "metadata": {}, "outputs": [ @@ -1419,143 +1386,419 @@ " \n", " \n", " 0\n", - " 3.8\n", - " 1.10\n", + " 14.2\n", + " 0.08\n", + " 1.66\n", + " 0.6\n", + " 0.346\n", + " 289.0\n", + " 9.0\n", + " 0.98711\n", + " 2.72\n", + " 0.22\n", + " 14.2\n", + " 6\n", + " \n", + " \n", + " 1\n", + " 14.2\n", + " 0.08\n", " 0.00\n", + " 0.6\n", + " 0.346\n", + " 289.0\n", + " 440.0\n", + " 0.98711\n", + " 2.72\n", + " 1.08\n", + " 14.2\n", + " 8\n", + " \n", + " \n", + " 2\n", + " 3.8\n", + " 0.08\n", + " 1.66\n", " 65.8\n", - " 0.009000\n", + " 0.346\n", " 289.0\n", - " 50.104997\n", - " 1.038893\n", + " 9.0\n", + " 0.98711\n", " 3.82\n", - " 0.220000\n", + " 0.22\n", " 8.0\n", + " 7\n", + " \n", + " \n", + " 3\n", + " 3.8\n", + " 0.08\n", + " 0.00\n", + " 0.6\n", + " 0.346\n", + " 289.0\n", + " 9.0\n", + " 0.98711\n", + " 2.72\n", + " 0.22\n", + " 14.2\n", " 5\n", " \n", " \n", - " 1\n", + " 4\n", " 14.2\n", - " 0.08\n", + " 1.10\n", " 1.66\n", " 0.6\n", - " 0.251377\n", + " 0.346\n", " 289.0\n", - " 9.000000\n", - " 0.987291\n", - " 3.82\n", - " 1.080000\n", + " 440.0\n", + " 1.03898\n", + " 2.72\n", + " 1.08\n", + " 14.2\n", + " 6\n", + " \n", + " \n", + " 5\n", + " 14.2\n", + " 0.08\n", + " 1.66\n", + " 0.6\n", + " 0.009\n", + " 2.0\n", + " 9.0\n", + " 0.98711\n", + " 2.72\n", + " 1.08\n", " 8.0\n", " 6\n", " \n", " \n", - " 2\n", + " 6\n", " 3.8\n", " 1.10\n", + " 1.66\n", + " 65.8\n", + " 0.009\n", + " 289.0\n", + " 440.0\n", + " 1.03898\n", + " 
2.72\n", + " 1.08\n", + " 8.0\n", + " 5\n", + " \n", + " \n", + "\n", + "" + ], + "text/plain": [ + " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", + "0 14.2 0.08 1.66 0.6 0.346 \n", + "1 14.2 0.08 0.00 0.6 0.346 \n", + "2 3.8 0.08 1.66 65.8 0.346 \n", + "3 3.8 0.08 0.00 0.6 0.346 \n", + "4 14.2 1.10 1.66 0.6 0.346 \n", + "5 14.2 0.08 1.66 0.6 0.009 \n", + "6 3.8 1.10 1.66 65.8 0.009 \n", + "\n", + " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", + "0 289.0 9.0 0.98711 2.72 0.22 \n", + "1 289.0 440.0 0.98711 2.72 1.08 \n", + "2 289.0 9.0 0.98711 3.82 0.22 \n", + "3 289.0 9.0 0.98711 2.72 0.22 \n", + "4 289.0 440.0 1.03898 2.72 1.08 \n", + "5 2.0 9.0 0.98711 2.72 1.08 \n", + "6 289.0 440.0 1.03898 2.72 1.08 \n", + "\n", + " alcohol quality \n", + "0 14.2 6 \n", + "1 14.2 8 \n", + "2 8.0 7 \n", + "3 14.2 5 \n", + "4 14.2 6 \n", + "5 8.0 6 \n", + "6 8.0 5 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin.generate(len(outcome), cond=outcome)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ed7f9903", + "metadata": {}, + "source": [ + "Use an array as the `cond` argument of the `fit` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "8d90f2fa", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-04-06T19:23:00.914877+0200][45392][INFO] Encoding fixed acidity 8821222230854998919\n", + "[2023-04-06T19:23:00.923931+0200][45392][INFO] Encoding volatile acidity 3689048099044143611\n", + "[2023-04-06T19:23:00.934174+0200][45392][INFO] Encoding citric acid 735380040632581265\n", + "[2023-04-06T19:23:00.954557+0200][45392][INFO] Encoding residual sugar 2442409671939919968\n", + "[2023-04-06T19:23:00.965758+0200][45392][INFO] Encoding chlorides 7195838597182208600\n", + "[2023-04-06T19:23:00.976757+0200][45392][INFO] Encoding free sulfur dioxide 
3309873879720413309\n", + "[2023-04-06T19:23:00.996365+0200][45392][INFO] Encoding total sulfur dioxide 8059822526963442530\n", + "[2023-04-06T19:23:01.005686+0200][45392][INFO] Encoding density 3625281346475756911\n", + "[2023-04-06T19:23:01.014352+0200][45392][INFO] Encoding pH 4552002723230490789\n", + "[2023-04-06T19:23:01.021350+0200][45392][INFO] Encoding sulphates 4957484118723629481\n", + "[2023-04-06T19:23:01.029350+0200][45392][INFO] Encoding alcohol 3711001505059098944\n", + "[2023-04-06T19:23:01.036351+0200][45392][INFO] Encoding quality 3457201635469827215\n", + "[2023-04-06T19:23:07.229567+0200][45392][INFO] Step 100: MLoss: 1.3287 GLoss: 0.9813 Sum: 2.31\n", + "[2023-04-06T19:23:13.660368+0200][45392][INFO] Step 200: MLoss: 1.2782 GLoss: 0.9404 Sum: 2.2186\n", + "[2023-04-06T19:23:21.260768+0200][45392][INFO] Step 300: MLoss: 1.2039 GLoss: 0.8899 Sum: 2.0938\n", + "[2023-04-06T19:23:29.299141+0200][45392][INFO] Step 400: MLoss: 1.1596 GLoss: 0.8612 Sum: 2.0208\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import random\n", + "from sklearn.preprocessing import LabelEncoder\n", + "cond = random.choices(['red', 'white', 'rose'], k=len(loader))\n", + "cond = LabelEncoder().fit_transform(cond)\n", + "plugin.fit(loader, cond=cond)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "8c07b5a8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiwAAAGwCAYAAACKOz5MAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB+d0lEQVR4nO3dd3hT5d/H8XfSvfeEUsoqe0MpS5DKUJHhAEQZAiqCigwVB7gRXODjwJ+DoQKiMpShzLJnS1mFQqHQUrpo6Z5JzvNHaKBSoDuhfF/Xda6kJ+ec3HcD5MN97qFSFEVBCCGEEMKEqY1dACGEEEKIO5HAIoQQQgiTJ4FFCCGEECZPAosQQgghTJ4EFiGEEEKYPAksQgghhDB5EliEEEIIYfLMjV2AqqDT6bh8+TIODg6oVCpjF0cIIYQQZaAoCllZWfj6+qJW374NpVYElsuXL+Pn52fsYgghhBCiAuLi4qhbt+5tj6kVgcXBwQHQV9jR0dHIpRFCCCFEWWRmZuLn52f4Hr+dWhFYim8DOTo6SmARQggh7jJl6c4hnW6FEEIIYfIksAghhBDC5ElgEUIIIYTJqxV9WIQQQojK0Gq1FBUVGbsYtZKFhQVmZmaVvo4EFiGEEPcsRVFITEwkPT3d2EWp1ZydnfH29q7UXGkSWIQQQtyzisOKp6cntra2MvloFVMUhdzcXJKTkwHw8fGp8LUksAghhLgnabVaQ1hxc3MzdnFqLRsbGwCSk5Px9PSs8O0h6XQrhBDinlTcZ8XW1tbIJan9in/HleknJIFFCCHEPU1uA1W/qvgdS2ARQgghhMmTwCKEEEIIkyeBRQghhKjlQkNDUalUd/XwbQksd5BblMuJKyeMXQwhhBDinibDmm8jLiuOR1Y/goWZBbuH78bSzNLYRRJCCCHuSeVqYZkzZw6dOnXCwcEBT09PBg8eTFRU1G3PWbx4MSqVqsRmbW1d4hhFUZg1axY+Pj7Y2NgQEhLC2bNny1+bKlbXvi5OVk7kafI4mnLU2MURQghRzRRFIbdQY5RNUZQyl7NXr168+OKLTJkyBRcXF7y8vPj+++/Jyclh7NixODg40KhRIzZu3HjLa/z555+0aNECKysr6tevz2effVbi9W+++YbGjRtjbW2Nl5cXjz32mOG1P/74g1atWmFjY4ObmxshISHk5OSU/xdeDuVqYdmxYweTJk2iU6dOaDQa3njjDfr27UtkZCR2dna3PM/R0bFEsPnv8KZ58+bx5ZdfsmTJEgICAnj77bfp168fkZGRN4WbmqRSqQjyCWJDzAYOJBygk3cno5VFCCFE9csr0tJ81r9Gee/I9/pha1n2r+UlS5bw6quvcvDgQX777TcmTpzI6tWrGTJkCG+88QZffPEFTz/9NLGxsTedGxYWxhNPPME777zDsGHD2Lt3Ly+88AJubm6MGTOGw4cP89JLL/Hzzz/TtWtX0tLS2LVrFwAJCQmMGDGCefPmMWTIELKysti1a1e5AldFqJRKvENKSgqenp7s2LGDnj17lnrM4sWLmTJlyi07+iiKgq+vL9OmTWP69OkAZGRk4OXlxeLFixk+fPhN5xQUFFBQUGD4OTMzEz8/PzIyMnB0dKxodUq1+uxqZu2dRRuPNvzy4C9Vem0hhBDGk5+fT0xMDAEBAYb/HOcWau6KwNKrVy+0Wq0hRGi1WpycnBg6dChLly4F9MsO+Pj4sG/fPvLz8+nduzdXr17F2dmZkSNHkpKSwqZNmwzXfPXVV1m/fj0nT55k1apVjB07lkuXLuHg4FDivcPDw+nQoQMXLlzA39+/TOUt7XcN+u9vJyenMn1/V6oPS0ZGBgCurq63PS47Oxt/f390Oh3t27fno48+okWLFgDExMSQmJhISEiI4XgnJyeCgoLYt29fqYFlzpw5vPvuu5Upepl18ekCwIkrJ8guzMbe0r5G3lcIIUTNs7EwI/K9fkZ77/Jo3bq14bmZmRlubm60atXKsM/LywvQT4n
/3zBw6tQpBg0aVGJft27dmD9/PlqtlgceeAB/f38aNGhA//796d+/P0OGDMHW1pY2bdrQp08fWrVqRb9+/ejbty+PPfYYLi4u5a1yuVR4lJBOp2PKlCl069aNli1b3vK4wMBAfvrpJ9auXcsvv/yCTqeja9euXLp0CdAnQLj+iy3m5eVleO2/Zs6cSUZGhmGLi4uraDXuyMfeh3oO9dAqWg4nHa629xFCCGF8KpUKW0tzo2zlnQ3WwsLiprLfuK/4ejqdrty/BwcHB8LDw1m+fDk+Pj7MmjWLNm3akJ6ejpmZGZs3b2bjxo00b96c//u//yMwMJCYmJhyv095VDiwTJo0iRMnTrBixYrbHhccHMyoUaNo27Yt9913H6tWrcLDw4Pvvvuuom+NlZUVjo6OJbbqVNzKciDhQLW+jxBCCFETmjVrxp49e0rs27NnD02aNDEsTmhubk5ISAjz5s3j2LFjXLhwgW3btgH6MNStWzfeffddjhw5gqWlJatXr67WMlfoltDkyZNZt24dO3fupG7duuU618LCgnbt2hEdHQ2At7c3AElJSSWWnU5KSqJt27YVKV6VC/IJYuWZlexP2G/sogghhBCVNm3aNDp16sT777/PsGHD2LdvH1999RXffPMNAOvWreP8+fP07NkTFxcXNmzYgE6nIzAwkAMHDrB161b69u2Lp6cnBw4cICUlhWbNmlVrmcvVwqIoCpMnT2b16tVs27aNgICAcr+hVqvl+PHjhnASEBCAt7c3W7duNRyTmZnJgQMHCA4OLvf1q0Nn786oUBGdHs2VvCvGLo4QQghRKe3bt2flypWsWLGCli1bMmvWLN577z3GjBkDgLOzM6tWreL++++nWbNmLFy4kOXLl9OiRQscHR3ZuXMnDz74IE2aNOGtt97is88+Y8CAAdVa5nKNEnrhhRdYtmwZa9euJTAw0LDfyckJGxsbAEaNGkWdOnWYM2cOAO+99x5dunShUaNGpKen88knn7BmzRrCwsJo3rw5AHPnzuXjjz8uMaz52LFjZR7WXJ5exhX1xN9PcCrtFHN6zOHhBg9Xy3sIIYSoObcauSKqXo2PEvr2228B/XCqGy1atMiQymJjY1GrrzfcXL16lQkTJpCYmIiLiwsdOnRg7969hrAC+qFUOTk5PPvss6Snp9O9e3f++ecfk/oD1MW3C6fSTnEg4YAEFiGEEKKGVWoeFlNREy0se+P38tyW5/C282bTo5vK3ZtbCCGEaZEWlppTFS0ssvhhGbXzaoeF2oLEnERis26eNVAIIYQQ1UcCSxnZmNvQ1rMtAPsvy2ghIYQQoiZJYCkHw3wsiTIfixBCCFGTJLCUQ5BPEKCfQE6r0xq5NEIIIcS9QwJLObRwa4G9hT2ZhZmcTjtt7OIIIYQQ9wwJLOVgrjano3dHAJn1VgghhKhBEljKqbgfiwQWIYQQpqx+/frMnz/f2MWoMhJYyqk4sBxJPkKBtsDIpRFCCCHuDRJYyqmBUwM8bDwo0BYQkRxh7OIIIYQQ9wQJLOWkUqlKjBYSQgghjCErK4uRI0diZ2eHj48PX3zxBb169WLKlCmlHh8bG8ugQYOwt7fH0dGRJ554gqSkJMPrR48epXfv3jg4OODo6EiHDh04fPgwABcvXmTgwIG4uLhgZ2dHixYt2LBhQ01U06BcawkJvS4+XVh3fh37E/bzEi8ZuzhCCCGqiqJAUa5x3tvCFsqx7MvUqVPZs2cPf/31F15eXsyaNYvw8HDatm1707E6nc4QVnbs2IFGo2HSpEkMGzaM0NBQAEaOHEm7du349ttvMTMzIyIiAgsLCwAmTZpEYWEhO3fuxM7OjsjISOzt7aui1mUmgaUCiltYTqaeJLMwE0fL6lm/SAghRA0ryoWPfI3z3m9cBku7Mh2alZXFkiVLWLZsGX369AH0CxH7+pZe9q1bt3L8+HFiYmLw8/MDYOnSpbRo0YJDhw7RqVMnYmNjmTFjBk2bNgWgcePGhvN
jY2N59NFHadWqFQANGjSocDUrSm4JVYC3nTf1HeujU3QcSjxk7OIIIYS4x5w/f56ioiI6d+5s2Ofk5ERgYGCpx586dQo/Pz9DWAFo3rw5zs7OnDp1CtC32IwfP56QkBA+/vhjzp07Zzj2pZde4oMPPqBbt27Mnj2bY8eOVVPNbk1aWCooyCeIC5kXOJBwgD71+hi7OEIIIaqCha2+pcNY721E77zzDk8++STr169n48aNzJ49mxUrVjBkyBDGjx9Pv379WL9+PZs2bWLOnDl89tlnvPjiizVWPmlhqaBgn2BA5mMRQohaRaXS35YxxlaO/isNGjTAwsKCQ4eut/JnZGRw5syZUo9v1qwZcXFxxMXFGfZFRkaSnp5O8+bNDfuaNGnCK6+8wqZNmxg6dCiLFi0yvObn58fzzz/PqlWrmDZtGt9//315frOVJoGlgjp6d0StUhOTEUNSTtKdTxBCCCGqiIODA6NHj2bGjBls376dkydPMm7cONRqNapSgk9ISAitWrVi5MiRhIeHc/DgQUaNGsV9991Hx44dycvLY/LkyYSGhnLx4kX27NnDoUOHaNasGQBTpkzh33//JSYmhvDwcLZv3254raZIYKkgJysnmrvqU6ms3iyEEKKmff755wQHB/Pwww8TEhJCt27daNasGdbW1jcdq1KpWLt2LS4uLvTs2ZOQkBAaNGjAb7/9BoCZmRmpqamMGjWKJk2a8MQTTzBgwADeffddALRaLZMmTaJZs2b079+fJk2a8M0339RofVWKoig1+o7VIDMzEycnJzIyMnB0rLkRO/PD5vPjiR95pOEjfNj9wxp7XyGEEJWXn59PTEwMAQEBpX7J321ycnKoU6cOn332GePGjTN2cUq41e+6PN/f0sJSCV18r60rdHk/Fc19tSAvCiGEMIIjR46wfPlyzp07R3h4OCNHjgRg0KBBRi5Z9ZDAUgltPdpiqbYkOS+ZmMyYcp2bp8nj+c3P0/fPvhxJPlJNJRRCCFGbffrpp7Rp04aQkBBycnLYtWsX7u7uxi5WtZDAUgnW5ta082oH6FtZyqpIV8T0HdPZc3kPiTmJjP93PJsubKquYgohhKiF2rVrR1hYGNnZ2aSlpbF582bDxG61kQSWSipevbms6wopisK7e99l56WdWJlZ0dm7M4W6QqbvmM6Sk0vkFpEQQghRCgkslRTkrZ+m/1DiITQ6zR2PXxC+gLXn1mKmMuPT+z7lfw/8j+GBw1FQ+PTwp8w5OAetTlvdxRZCCCHuKhJYKqm5W3McLBzIKsriVOqp2x77c+TP/HjiRwBmB8+ml18vzNRmvBH0BtM7Tgdg+enlvBL6CnmavGovuylKzk1m3+V95Bpr8TEhhBAmSQJLJZmpzejk3Qm4/ay3G85vYN6heQC83P5lhjQeYnhNpVIxusVoPr3vUyzVlmyP2864f8eRmpdavYU3ARqdhrCkMOaHzeexvx6jz+99eHbzs/T/sz8/nfhJgosQQghAAkuVKB7efKt+LHvj9/LmnjcBGNlsJONalj4+vl/9fnzf93ucrJw4fuU4T214ipiM8o0+uhsk5yaz+uxqpoZOpeeKnoz5Zww/nviRqKtRqFDhYuXC1YKrfBH2BQNWDWDxicUSXIQQ4h4nix9WgSAffT+WI8lHyNfkY21+fVKcE1dOMCV0Chqdhv71+/Nqp1dLnTa5WHuv9vw84GcmbpnIpexLPL3xab7s/SXtvdpXez2qi0an4WjKUXbH72Z3/G5Op50u8bqTlRPdfLvRvU53utXphqOlIxtiNrDw6ELisuL4LOwzFp1cxDMtn+GJwCewMbcxUk2qTk5RDufSz9Hao7WxiyKEEHcFmem2CiiKQsgfISTnJvO/B/5HsK9+YcSLmRd5esPTXC24ShefLnzd52sszSzLdM3UvFRe3PYix68cx1JtyUc9PqJf/X7VWY1qkZSTxLhN47iYedGwT4WKFm4t6F63O93rdKelW0vM1GY3navRafj
73N98d+w74rPjAXC3ceeZls/weJPHSwTDu4miKIzfNJ6DiQd5r+t7JW4PCiFqzt08022vXr1o27Yt8+fPN3ZRykRmujURKpXKMLy5uB9LSm4Kz21+jqsFV2nu1pz5veeXOawAuNm48WO/H+nt19sw7HnxicV31bDn3KJcXtz2IhczL+Jg6cCAgAF81P0jQoeFsvzh5UxqO4k2Hm1KDSsA5mpzhjQewt9D/ubdru/ia+fLlbwrzDs0jwdXPcivp36lQFtQw7WqvM0XN3Mw8SAA88Pnk1WYZeQSCSGE6ZPAUkVunI8lqzCLiVsmEp8dTz2HenzT5xvsLOzKfU0bcxu+6PUFTzZ9EoDPwj7jzd1vkp6fXpVFrxZanZbXdr7GqbRTuFq7svLhlczrOY+BDQfiau1armtZqC0Y2ngo64asY1bwLHzsfEjJS+Hjgx/z4KoH+evcX9VUi6pXoC3g87DPATBXmZOWn8b/jv3PyKUSQgjTJ4GlihT3Y4lMjeSFLS8QdTUKN2s3Fj6wEDcbtwpf10xtxuudX2dGxxmoUPH3+b95ZM0j/HXuL5Nubfn08KeEXgrFUm3Jgt4LqOtQt9LXtDCz4PEmj7NuyDre7vI2XrZeJOcm8+buNwlLCquCUle/nyN/Jj47Hk9bT+bdpx819supX0rcMhNCiPK4evUqo0aNwsXFBVtbWwYMGMDZs2cNr1+8eJGBAwfi4uKCnZ0dLVq0YMOGDYZzR44ciYeHBzY2NjRu3JhFixYZqyq3JYGlinjaetLAqQEKChEpEdhZ2LHwgYX4OfhV+toqlYpRLUaxdMBSGjk34mrBVd7c/SYTNk8wyS+6FadX8MupXwD4sMeHtPVsW6XXtzSz5InAJ9gwdAMPBjwI6FfONuUAB/rbhMWtKa90eIUH/B+gR50eaHQaPj30qZFLJ4QAfR+z3KJco2wV/TdszJgxHD58mL/++ot9+/ahKAoPPvggRUVFAEyaNImCggJ27tzJ8ePHmTt3Lvb29gC8/fbbREZGsnHjRk6dOsW3335rsmsRySihKhTkE8T5jPNYqC34sveXNHVtWqXXb+vZlpUPr2RJ5BIWHl3IgYQDDF07lGdbP8szLZ/BwsyiQtct0BZwPOU49Rzr4WnrWaky7rq0izkH5wDwUruX6F+/f6WudzuWZpZM6ziNbbHbiEiJIDQulN71elfb+1XWgvAF5GnyaO3e2hC0ZnSawb7L+wi9FMrey3vp6tvVyKUU4t6Wp8kjaFmQUd77wJMHsLWwLdc5Z8+e5a+//mLPnj107ar/9+PXX3/Fz8+PNWvW8PjjjxMbG8ujjz5qWGeoQYMGhvNjY2Np164dHTt2BKB+/fpVU5lqIC0sVWhks5F08+3G/N7z6ezTuVrew8LMgvGtxrP6kdUE+wRTqCvkq4iveOzvxwhPCi/zdbILs9kYs5FpodPouaInY/8dy0OrHuLXU7+iU3QVKltUWhTTd0xHp+gY1HAQ41uNr9B1ysPT1pORzfRLqn955EuTXdbgZOpJ1p5bC8BrnV9DrdL/1QtwCmB40+EAzDs4r0zLOwghRLFTp05hbm5OUND1kOXm5kZgYCCnTulnX3/ppZf44IMP6NatG7Nnz+bYsWOGYydOnMiKFSto27Ytr776Knv37q3xOpSVtLBUIX9HfxY+sLBG3svP0Y/vHviODTH6GXTPZ5xn9D+jebTxo7zS4RWcrJxuOic1L5XtcdvZGruVAwkHKNIVGV6zs7AjpyiHjw9+zNbYrbzX9b1y9TtJyU1h8rbJ5Gpy6ezdmdnBs28730xVeqbVM/x+5nei06P5+/zfDG40uEbet6wURWHuwbkAPNzg4ZvmXnm+zfOsO7+OcxnnWBm1kiebPWmMYgoh0A92OPBk2RazrY73rg7jx4+nX79+rF+/nk2bNjFnzhw+++wzXnzxRQYMGMDFixfZsGEDmzdvpk+fPkyaNIlPPzW929QyD0stkFGQwRdhX/Dn2T8BcLV25bV
OrzEgYACXcy6z9eJWtsZuJSIlokTrSX3H+vSp14c+9frQ3K05f5z5g8/CPiNPk4etuS3TOk7j8SaP3zF45BblMvbfsUSmRlLfsT6/PPhLqYGpOi06sYjPwz7H286bdUPWYWVmVaPvfzv/xPzDjJ0zsDG34a/Bf+Ft533TMSujVvL+/vdxsnJi/ZD1Nf77E+JeVBvmYZk0aRJNmjQpcUsoNTUVPz8/li5dymOPPXbTuTNnzmT9+vUlWlqKfffdd8yYMYPMzMwqLa/MwyIA/Uyx73R9h0X9FhHgFEBafhqv7XqNvn/2pf+f/fnk8CeEJ4ejU3Q0d2vOi+1eZO2gtfw95G+mdJhCK49WmKnNGNZ0GH8O/JP2nu3J1eTy/v73eX7L8yTmJN7yvbU6LTN3zSQyNRIXKxe+6fONUb5sRzQdgZetF4k5iaw4vaLG3/9W8jX5hmHMY1uOLTWsAAxtPJTGLo3JKMjgm4hvarKIQoi7WOPGjRk0aBATJkxg9+7dHD16lKeeeoo6deowaNAgAKZMmcK///5LTEwM4eHhbN++nWbNmgEwa9Ys1q5dS3R0NCdPnmTdunWG10yNBJZapKN3R/4Y+AeT2k7CUm1JYk4iapWajl4deb3z62x6dBO/Pfwbz7Z+lgbODUq9hp+jHz/1+4kZHWdgZWbF3st7Gbp2KGui15Tag/2LsC/YFrcNC7UFC+5fgJ9j5UdFVYS1uTUvtH0BgO+Pf28yk7EtObmEhJwEvO28GdNizC2PM1eb81qn1wD4Leo3zqWfq6ESCiHudosWLaJDhw48/PDDBAcHoygKGzZswMJCPxBDq9UyadIkmjVrRv/+/WnSpAnffKP/j5GlpSUzZ86kdevW9OzZEzMzM1asMJ3/9N1IbgnVUvHZ8USlRdHWs225J2ordj7jPG/tfovjV44D0KtuL2YFz8LD1gO4fhsDYG6PuTzY4MGqKXwFaXQahv41lJiMGJ5t/SwvtnvRqOVJykli4JqB5Gnyyvz7eXnby2yL20ZX364sDFlYY/2AhLgX3c23hO42NX5LaM6cOXTq1AkHBwc8PT0ZPHgwUVFRtz3n+++/p0ePHri4uODi4kJISAgHDx4sccyYMWNQqVQltv79q2847L2gjn0d7q93f4XDCkADpwYsHbCUl9u/jLnanNBLoQz5awgbYzayJ34PHx34CIDJbScbPayAvpXipXYvAfoJ2q7kXTFqeb488iV5mjzaerRlQMCAMp0zveN0LNQW7L28l52XdlZzCYUQ4u5RrsCyY8cOJk2axP79+9m8eTNFRUX07duXnJycW54TGhrKiBEj2L59O/v27cPPz4++ffsSHx9f4rj+/fuTkJBg2JYvX16xGokqZa42Z3yr8fz28G80c21GRkEGr+58lclbJ6NVtDzS8BGebf2ssYtp0KdeH1q7tyZPk8fCozUzYqs0x1OOG5YMeK3za2VuKfFz9OPp5k8D8MnhTyjSFt3hDCGEuDeUK7D8888/jBkzhhYtWtCmTRsWL15MbGwsYWG3nhb9119/5YUXXqBt27Y0bdqUH374AZ1Ox9atW0scZ2Vlhbe3t2FzcXGpWI1EtWji0oRfH/qViW0mYq4yR6No6OjVkXeC3zGp2xYqlYopHaYA8OeZP4nNjK3xMiiKwtxD+mHMjzR8hJbuLct1/rOtn8XN2o2LmRdZdnpZdRRRCCHuOpXqdJuRkQGAq2vZbzvk5uZSVFR00zmhoaF4enoSGBjIxIkTSU1NveU1CgoKyMzMLLGJ6mehtuCFti+w/OHlTO0wlQX3L6jw7LrVqZN3J7rV6YZG0fDVka9q/P03xmzkaMpRbMxteLn9y+U+387CznDewqMLSc279d8FIYS4V1Q4sOh0OqZMmUK3bt1o2bLs/4N87bXX8PX1JSQkxLCvf//+LF26lK1btzJ37lx27NjBgAED0GpLn7V0zpw5ODk5GTY/P+OMTLlXNXVtytiWY3G0NN0OzlPaTwFg44WNRKZ
G1tj75mnyDMOYx7caX+GlDgY1GkQz12ZkF2XzVUTNhy4h7iW1YOyJyauK33GFA8ukSZM4ceJEuYY/ffzxx6xYsYLVq1eX6CU8fPhwHnnkEVq1asXgwYNZt24dhw4dIjQ0tNTrzJw5k4yMDMMWFxdX0WqIWqqpa1PDej0LwhfU2PsuPrGYpNwkfO18GdV8VIWvo1apeb3z64D+1tbptNNVVUQhxDXFw35zc3ONXJLar/h3XPw7r4gKTc0/efJk1q1bx86dO6lbt2zTt3/66ad8/PHHbNmyhdatW9/22AYNGuDu7k50dDR9+vS56XUrKyusrExnJlNhmia3m8ymi5vYe3kvBxIOEORTvQuaJeYk8tOJnwB4peMrWJtXbphke6/29K/fn38u/MPcg3P5qd9PJtVfSIi7nZmZGc7OziQnJwNga2srf8eqmKIo5ObmkpycjLOzM2ZmZhW+VrkCi6IovPjii6xevZrQ0FACAgLKdN68efP48MMP+ffffw0rQt7OpUuXSE1NxcfHpzzFE6IEPwc/Hm/yOMtPL2d+2HyWPbSsWv8xmh8+n3xtPu0929PPv1+VXHNqh6lsj9vO4aTDLI1cir+jv34pek0uOUU55GpyySvKI1ejX56+eJ+iKHTx7UK/+v3wc7g7bplmFGTw1ZGvcLF2obN3Z1p7tMbSzNLYxRK1nLe3fvbp4tAiqoezs7Phd11R5Zo47oUXXmDZsmWsXbuWwMBAw34nJydsbPSLNo0aNYo6deowZ84cAObOncusWbNYtmwZ3bp1M5xjb2+Pvb092dnZvPvuuzz66KN4e3tz7tw5Xn31VbKysjh+/HiZWlJk4jhxK1fyrvDgqgfJ0+Tx2X2f0bd+3yq9vk7RsffyXpafXs7OSztRoWL5w8tp4daiyt7jqyNf8d2x7yp8fgu3FvSv359+9fvhY2+a/wnQ6DQ8v+V5DiRcX3TO2syatp5tCfIJopN3J1q4tcBcLeu1iuqh1WopKpJpBKqDhYXFLVtWyvP9Xa7Acqv/nS5atIgxY8YA+gWZ6tevz+LFiwGoX78+Fy9evOmc2bNn884775CXl8fgwYM5cuQI6enp+Pr60rdvX95//328vLzKVC4JLOJ2vo74moVHF1LfsT6rB62uki+9jIIM1kav5beo34jNuj50ekKrCbzU/qVKX/9GuUW5TN8xncvZl7G1sMXW3BYbCxvsLOywNdf/XLzf1sIWG3Mbcoty2Ry7mUOJh0oseNnGow396/enb/2+ZeoQnFWYxZmrZ4hKi+LM1TOcTjtNfHY8z7d5npHNRlZZHT899ClLIpdgY25Djzo9OJx0mLT8tBLH2FnY0d6zPUE+QXT27kygayBqlawuIsTdrNoCi6mSwCJuJ7swmwdXPcjVgqvMCp7F400er/C1otKiWH56ORtiNpCnyQPAwcKBQY0GMbzpcPwd/auq2FXiSt4Vtlzcwr8X/iUsKQwF/V93FSpDH5kH/B/AxdqF+Kx4oq5G6bdrASU+O/6W1/6o+0cMbDiw0mXccH4Dr+3Sr6P06X2f0q9+PxRF4Vz6OQ4mHuRg4kEOJR4is7Dk9AWOlo4E+wYzruU4mrmZ5mJtQojbk8AixH/8EvkLcw/NxcPGg/VD12NjblPmc4t0RWy9uJXlp5cTnhxu2N/YpTEjmo7goYCHsLWwrY5iV6nk3GQ2X9xsmCemmFqlxsbchpyi0mes9rbzJtAlkCYuTQh0DSQsKYzlp5djrjLny/u/pEfdHhUu0+m00zy94WnytfmMaznOMOnff+kUHVFpUYYAE5YUVqK8A+oPYHK7ydRzrFfhsgghap4EFiH+o1BbyCNrHiE+Ox4fOx9crV2xs7C7abO3sMfWwhZ7C3vsLOw4e/Usv5/5nZS8FADMVeb08e/DiKYjaO/Z/q4dUZCQncCmi5v4J+YfTqSeAMBSbUlD54YEugYS6BJIoKs+pDhZOZU4V6fomLlrJhtiNmBjbsMPfX+gtcftR/6VJj0
/neHrhxOfHU+3Ot34+v6vMVOXbQSBRqchMjWSZaeXseH8BhQUzFXmPNrkUZ5r/ZxhgU4hhGmTwCJEKf6J+YcZO2dU6Fx3G3ceb/I4jzV5rMKTwZmqy9mXyS3Kxd/JHwt12eZIKNIWMXnbZPZe3ouzlTNLBiyhgVODMr+nRqdh4paJ7E/YT137uqx4eMVNwaisotKiWBC+gF3xuwCwMbfhqWZPMablGJOe3FAIIYHF2MURJuxi5kVSclPIKcohpyiH7KJscotyyS7KNuy7cbO1sGVwo8GE1AsxyWUIjCm3KJdx/47jROoJfOx8+HnAz3jZla2j/GeHP2PxycXYmNvwy4O/0MSlSaXLczjxMPPD5xtudzlaOjK+1XhGNB1R6TlxhBDVQwKLEKJGpOWnMWrjKC5mXqSRcyOWDFhyx1aNjTEbeXXnq8D1TrZVRVEUQuNC+fLIl0SnRwPgaevJC21eYFCjQTIsWggTI4FFCFFj4rPjeXrD06TkpdDesz3fPfDdLVs0otKieGrDU3fsZFtZWp2WdefX8XXE1yTkJABQ37E+L7V/iZB6IXdt3yMhapvyfH/LJAZCiEqpY1+Hb0O+xd7CnvDkcF7b+Roaneam49Lz03l5+8vka/Pp5tuNF9u9WG1lMlObMajRINYNWcernV7FxcqFC5kXmBo6lac3Ps2R5CNV9l5ZhVn8eupXNsZsNPoieoqicOLKCQq1hUYthxDVQVpYhBBV4lDiIZ7f/DyFukIebfwos4NnG1oyqrKTbUVkF2az+ORilkYuNcyf06deH15u/zIBTmVbYuS/knKS+PXUr6w8s9IwxLqvf19md51ttM6+q8+uZtbeWQR5B7HwgYVyC0yYPLklJIQwii0XtzBtxzR0io7nWj/H5HaTAfj88OcsOrmoSjvZVkRKbgrfHP2GVWdXoVN0mKnMeKzJYzzf5nncbdzLdI1z6edYfHIx686vM7Qk+Tv6E58Vj0bR4Gvny9yec2nr2bYaa3IzRVEY+tdQQ9+dsS3HMrXD1BotgxDlJYFFCGE0K6NW8v7+9wF4M+hNnKycDJ1sP7nvE/rX72/M4gH60DE/fD6hcaGAfij02BZjGd1idKmTACqKwpHkIyw6sYjQS6GG/R29OjK25Vi61+nOySsneXXnq1zKvoSZyoxJbSfxTMtnyjy3TGUdTDjIuE3jMFebG4LU570+5wH/B2rk/YWoCAksQgij+vbot3wT8Q0qVFiaWVKgLeCZls/wSodXjF20Eg4nHubzsM85fuU4AG7WbrzQ9gWGNh6KudocnaJje9x2Fp1YZBgurUJFn3p9GNty7E0T5mUVZvH+/vfZGLMRgCCfID7q/lGNzN0zZfsUtsZuZVjgMKzMrFgauRRbc1uWP7y8XHPkCFGTJLAIIYxKURQ+PPAhv0X9BkA332583afsM9nWJEVR2HRxEwvCFxCXFQdAgFMADzd4mL/P/c2FzAuAfibgQY0GMar5KOo71b/t9dZEr2HOwTnkafJwsXLhg+4f0LNuz2qrw+XsywxYNQCdomPNoDX4O/ozYdMEDicdpoFTA5Y9tAw7C7tqe/+qcCnrEo5WjjLZ3z1GAosQwui0Oi2fHv6UuKw4Puz+YY12sq2IIm0RK8+sZOHRhaQXpBv2O1g6MDxwOE82e7LM/VwAzmec59UdrxJ1NQqAUc1HMaX9lGqZgHB+2Hx+PPEjQd5B/NDvB0C/8OWwv4eRnJfMA/4P8Nl9n5nscO5VZ1fxzt538LD1YOXDK3GzcTN2kUQNkcAihBAVlFWYxaITiziWcoz7/O5jaOOhFW6dKNAW8Pnhz1l2ehkAzd2aM6/nvCpd1Ttfk88DfzxAekE683vPp0+9PobXIpIjGPvvWDQ6DdM7Tmd0i9FV9r5V5ddTv/LxwY8NP3f17cq3Id+iVsmsG/cCmYdFCCEqyMHSgZfav8QP/X7g6eZPV+pWipWZFTODZrKg9wKcrJyITI3kib+f4O9zf1dZef+58A/
pBen42PnQq26vEq+19WzLq530HZ6/CPuCQ4mHqux9q8IPx38whJVHGj6CtZk1ey/v5cfjPxq5ZMIUSWARQohqdn+9+/lj4B908OpAriaXN3a/wS+Rv1T6uoqisOyUvvVmWOCwUvsIDQ8czsAGA9EqWqbvmE5iTmKl37eyFEXhy/AvWRC+AIAX2rzAB90+4I2gNwD4KuIrwpLCjFlEYYIksAghRA3wtvPmx74/Mq7lOAC+PPJlpcPD0ZSjnEo7hZWZFY82frTUY1QqFW8Hv02gSyBp+WlM2zHNqDPhKorCvEPz+P749wBM6zCNiW0nolKpGNxoMA83eBidouPVna9yNf+q0copTI8EFiGEqCFmajNeav8S7TzbkafJY96heZW6XnHfmAEBA3C2dr7lcTbmNnzR6wscLB04lnKs0u9bUVqdlnf3vcsvp/StS28FvcWYlmMMr6tUKt7u8jb1HeuTnJvMG7vfQKfojFJWYXoksAghRA1Sq9S8GfQmZiozNl/czJ74PRW6TkpuCpsvbAbgyaZP3vF4P0c/Pu6h7y/yW9Rv/HXurwq9b0VpdBre2P0Gf579E7VKzQfdPmBY02E3HWdrYcun932KlZkVu+N3s+Tkkkq/965Lu9h7eW+lryOMSwKLEELUsEDXQEY0HQHARwc+okBbUO5r/HHmDzSKhnae7Wjm1qxM5/Ss25OJbSYC8N6+9ziddrrc71sRhdpCpoVOY0PMBsxV5szrOY9BjQbd8vhA10Be6/waAAvCFxCRHFGh980tyuXN3W/ywtYXeG7zc/wc+XOFriNMgwQWIYQwgkltJ+Fh40FsViyLTiwq17nFc8YAhuBTVs+3eZ7udbpToC1gyvYpZBRklOv88srT5PHStpfYFrcNS7Ul83vPp1/9fnc877HGjzGg/gC0ipZXd75a7nKeSz/Hk+ufLNGSNO/QPH44/kO56yBMgwQWIYQwAntLe2Z0mgHoh/cWz7JbFpsvbuZK3hU8bDwI8Q8p1/uqVWo+7vExdezrEJ8dz7TQaey6tIvUvNRyXacscopymLhlInsu78HG3IavQ77mPr/7ynSuSqViVvAs6jnUIyEngbf2vEVZpw1bG72WEetHcC7jHB42HvzU7ydeaPMCoG+x+SbimzJfS5gOmThOCCGMRFEUJmyawIHEA/Ss25Ov7v+qTLPRPr3haSJSInih7QuGWzzldSr1FE9vfLrE7SgvWy+auzWnuVtzWri1oLlb83LPOqvRaUgvSCclN4X397/P8SvHsbew55uQb2jn2a5C5Ry5YSRFuiJmdJzBqBajbnlsniaPjw58xJroNQB08enCnB5zDDMU/3D8B8NQ6mdaPsOU9lNMdvbfe4XMdCuEEHeJ8xnnefSvR9HoNCzovYD7691/2+MjUyMZtm4Y5mpzNj+2uVzLBfxXWFIYK6NWEpkaycXMiyjc/HVwY4hp7tYcgNS8VFLzUw2PaflppObpH6/mXy1xHScrJ7574DtauLWocDlXnF7Bhwc+xFxtztL+S2nl0eqmY86ln2P6julEp0ejVql5oc0LjG81/qa5aX6O/NkwSmpks5G81uk1CS1GJIFFCCHuIsVrAfna+bJm8BpszG1ueezbe95mTfQaHgx4kLk951ZZGXKKcjiVeorI1Egi0yKJTI3kQsaFUkPMnahQ4WLtQn3H+rzV5S0auzSuVNkURWHajmlsvriZOvZ1WDlwZYlFEv8+9zfv73+fPE0e7jbuzO0xl84+nW95vZVRK3l///sAPN7kcd7q8pYsBWAkEliEEOIukluUy+C1g0nISWBCqwm81P6lUo+7mn+VkN9DKNQV8vOAn2nr2bZay5VTlMPptNOcvHKSyLRIotKisFBb4Grjipu1G242brhZu+Fq7Wp47mbjhouVS5WvzJ1VmMUTfz/BpexL9KnXhy96fUG+Np85B+awOno1AEE+QXzc4+MytTqtiV7DrD2zUFB4pOEjvNf1PZNcTby2k8AihBB3ma2
xW5myfQrmanNWPbKKAKeAm4758fiPzA+fT3O35qx4aMU9dyvj5JWTPLXxKTQ6DWNbjmXXpV1Ep0ejQsXEthN5ttWz5QodG85v4I3db6BVtAyoP4APe3yIhbrqV9MWtyaLHwohxF3mfr/76VGnBxqdhg8PfHjTKBaNTsNvUb8B+oni7rWwAtDCvQXTOkwDYNGJRUSnR+Nm7cb3fb9nYpuJ5W4hebDBg3x636eYq83ZeGEjM3bMoEhbVB1FF1VAWliEEMJExGXGMXjtYAp1hXzS8xP6B/Q3vLb14lamhE7BxcqFzY9vxsrMyoglNR5FUZgaOpUtsVsI8g7i455luwV0OzvidvBK6CsU6YroWbcnn/f6/Kbfr0anISE7gbisOGKzYonLijNsFmoL+tXvx0MNHsLbzrtSZbnXyC0hIYS4S3179Fu+ifgGTxtP/hryF3YWdgCM/3c8BxIPML7VeF5u/7KRS2lcWp2WM1fP0MSlSZX1O9l7eS8vb3uZfG0+QT5B9KjTg7isOC5lXSI2K5aE7AQ0iua211ChorN3ZwY2HEiIf4jhsxO3JoFFCCHuUgXaAoasHUJcVhyjmo9iRqcZRF+NZshfQ1Cr1Pwz9B987H2MXcxa6VDiISZtnUSeJq/U163MrKhrXxc/Bz/8HP30jw5+JOUk8ff5vwlLCjMca21mzf317ueRho8Q5BOEudq8pqpxV5HAIoQQd7Hd8buZuGUiZiozVg5cycqolfwW9Rsh9UL4ovcXxi5erXYs5Rg/HP8BSzNLQyAp3jxtPW87/PlS1iXWn1/P3+f/5mLmRcN+dxt3Hgx4kEcaPkKga2BNVOOuIYFFCCHucq9sf4UtsVto5d6K6PRo8jR5/NTvJzp5dzJ20cQdKIrC8SvH+fvc3/xz4R/SC9INrzV2aUwj50bYWdhhZ26HnYUdtha2+p+vbbbm1392s3Gr1beWJLAIIcRdLjEnkUfWPGK4PdHIuRGrHll1T44OupsVaYvYFb+LdefXERoXSpGufKOQLNQWPNn0SZ5t82yJyfJqCwksQghRC/x04ie+CNPfAnq7y9s8EfiEkUskKiOjIINd8btIy0sjR5NDblEuOUU55BRde6654XlRDtlF2YbA6mzlzMQ2E3k88PFaNVeMBBYhhKgFirRFjN80nszCTH598FdsLWyNXSRRgxRFYXf8bj49/CnnM84DUN+xPtM7Tqdn3Z61orVNAosQQghRS2h0Gv488ydfR3zN1YKrgH4ZghkdZ1SoE6+iKFzKukRcVhx+jn7Usa9jtLWUJLAIIYQQtUxWYRbfH/+eXyJ/oUhXhAoVQxoPYXLbyXjYetzyPI1OQ9TVKI4kHSE8OZwjyUe4knfF8LqNuQ0BTgE0cm5EQ+eGhkcfO59qDzISWIQQQoha6lLWJeaHz+ffC/8C+sAxruU4RrUYhY25DblFuRy/cpzw5HDCk8I5mnL0prllLNQW1LGvw+XsyxTqCkt9HxtzGxo6NSwRYrr4dMHCrOr60EhgEUIIIWq5iOQIPjn0CceuHAPAy9YLDxsPTqWdQqtoSxzrYOFAW8+2tPdqTzvPdrR0b4mVmRVanZa4rDjOpZ8jOj1a/5gRzYWMCzeNaDJTmXFw5EEszSyrrA7VFljmzJnDqlWrOH36NDY2NnTt2pW5c+cSGHj7e2i///47b7/9NhcuXKBx48bMnTuXBx980PC6oijMnj2b77//nvT0dLp168a3335L48aNy1QuCSxCCCHuRYqi8M+Ff/gi7AsSchIM+71svWjv1Z4Onh1o59WORs6NynV7R6PTEJsVqw8wV6OJTo+mQFvAV32+qtLyV1tg6d+/P8OHD6dTp05oNBreeOMNTpw4QWRkJHZ2pU9ss3fvXnr27MmcOXN4+OGHWbZsGXPnziU8PJyWLVsCMHfuXObMmcOSJUsICAjg7bff5vjx40RGRmJtbV2lFRZCCCFqm3xNPpsubkKtUtPBs8Nds3xDjd0SSklJwdPTkx0
7dtCzZ89Sjxk2bBg5OTmsW7fOsK9Lly60bduWhQsXoigKvr6+TJs2jenTpwOQkZGBl5cXixcvZvjw4XcshwQWIYQQ4u5Tnu/vSnX/zcjIAMDV1fWWx+zbt4+QkJAS+/r168e+ffsAiImJITExscQxTk5OBAUFGY75r4KCAjIzM0tsQgghhKi9KhxYdDodU6ZMoVu3boZbO6VJTEzEy8urxD4vLy8SExMNrxfvu9Ux/zVnzhycnJwMm5+fX0WrIYQQQoi7QIUDy6RJkzhx4gQrVqyoyvKUycyZM8nIyDBscXFxNV4GIYQQQtQc84qcNHnyZNatW8fOnTupW7fubY/19vYmKSmpxL6kpCS8vb0Nrxfv8/HxKXFM27ZtS72mlZUVVlZWFSm6EEIIIe5C5WphURSFyZMns3r1arZt20ZAQMAdzwkODmbr1q0l9m3evJng4GAAAgIC8Pb2LnFMZmYmBw4cMBwjhBBCiHtbuVpYJk2axLJly1i7di0ODg6GPiZOTk7Y2NgAMGrUKOrUqcOcOXMAePnll7nvvvv47LPPeOihh1ixYgWHDx/mf//7HwAqlYopU6bwwQcf0LhxY8OwZl9fXwYPHlyFVRVCCCHE3apcgeXbb78FoFevXiX2L1q0iDFjxgAQGxuLWn294aZr164sW7aMt956izfeeIPGjRuzZs2aEh11X331VXJycnj22WdJT0+ne/fu/PPPP2Wag0UIIYQQtZ9MzS+EEEIIo6ixeViEEEIIIWqCBBYhhBBCmDwJLEIIIYQweRJYhBBCCGHyJLAIIYQQwuRJYBFCCCGEyZPAIoQQQgiTJ4FFCCGEECZPAosQQgghTJ4EFiGEEEKYPAksQgghhDB5EliEEEIIYfIksAghhBDC5ElgEUIIIYTJk8AihBBCCJMngUUIIYQQJk8CixBCCCFMngQWIYQQQpg8CSxCCCGEMHkSWIQQQghh8iSwCCGEEMLkSWARQgghhMmTwCKEEEIIkyeBRQghhBAmTwKLEEIIIUyeBBYhhBBCmDwJLEIIIYQweRJYhBBCCGHyJLAIIYQQwuRJYBFCCCGEyZPAIoQQQgiTJ4FFCCGEECZPAosQQgghTJ4EFiGEEEKYPAksQgghhDB5EliEEEIIYfIksAghhBDC5ElgEUIIIYTJk8AihBBCCJMngUUIIYQQJq/cgWXnzp0MHDgQX19fVCoVa9asue3xY8aMQaVS3bS1aNHCcMw777xz0+tNmzYtd2WEEEIIUTuVO7Dk5OTQpk0bvv766zIdv2DBAhISEgxbXFwcrq6uPP744yWOa9GiRYnjdu/eXd6iCSGEEKKWMi/vCQMGDGDAgAFlPt7JyQknJyfDz2vWrOHq1auMHTu2ZEHMzfH29i7TNQsKCigoKDD8nJmZWebyCCGEEOLuU+N9WH788UdCQkLw9/cvsf/s2bP4+vrSoEEDRo4cSWxs7C2vMWfOHEMQcnJyws/Pr7qLLYQQQggjqtHAcvnyZTZu3Mj48eNL7A8KCmLx4sX8888/fPvtt8TExNCjRw+ysrJKvc7MmTPJyMgwbHFxcTVRfCGEEEIYSblvCVXGkiVLcHZ2ZvDgwSX233iLqXXr1gQFBeHv78/KlSsZN27cTdexsrLCysqquosrhBBCCBNRYy0siqLw008/8fTTT2NpaXnbY52dnWnSpAnR0dE1VDohhBBCmLIaCyw7duwgOjq61BaT/8rOzubcuXP4+PjUQMmEEEIIYerKHViys7OJiIggIiICgJiYGCIiIgydZGfOnMmoUaNuOu/HH38kKCiIli1b3vTa9OnT2bFjBxcuXGDv3r0MGTIEMzMzRowYUd7iCSGEEKIWKncflsOHD9O7d2/Dz1OnTgVg9OjRLF68mISEhJtG+GRkZPDnn3+yYMGCUq956dIlRowYQWpqKh4eHnTv3p39+/fj4eFR3uIJIYQQohZSKYqiGLsQlZWZmYmTkxMZGRk4OjoauzhCCCGEKIPyfH/LWkJ
CCCGEMHkSWIQQQghh8iSwCCGEEMLkSWARQgghhMmTwCKEEEIIkyeBRQghhBAmTwKLEEIIIUyeBBYhhBBCmDwJLEIIIYQweRJYhBBCCGHyJLAIIYQQwuRJYBFCCCGEyZPAIoQQQgiTJ4FFCCGEECZPAosQQgghTJ4EFiGEEEKYPAksQgghhDB5EliEEEIIYfIksAghhBDC5ElgEUIIIYTJk8AihBBCCJMngUUIIYQQJk8CixBCCCFMngQWIYQQQpg8CSxCCCGEMHkSWIQQQghh8iSwCCGEEMLkSWARQgghhMmTwCKEEEIIkyeBRQghhBAmTwKLEEIIIUyeBBYhhBBCmDwJLEIIIYQweRJYhBBCCGHyJLAIIYQQwuRJYBFCCCGEyZPAIoQQQgiTJ4FFCCGEECZPAosQQgghTF65A8vOnTsZOHAgvr6+qFQq1qxZc9vjQ0NDUalUN22JiYkljvv666+pX78+1tbWBAUFcfDgwfIWTQghhBC1VLkDS05ODm3atOHrr78u13lRUVEkJCQYNk9PT8Nrv/32G1OnTmX27NmEh4fTpk0b+vXrR3JycnmLJ4QQQohayLy8JwwYMIABAwaU+408PT1xdnYu9bXPP/+cCRMmMHbsWAAWLlzI+vXr+emnn3j99dfL/V5CCCGEqF1qrA9L27Zt8fHx4YEHHmDPnj2G/YWFhYSFhRESEnK9UGo1ISEh7Nu3r9RrFRQUkJmZWWITQgghRO1V7YHFx8eHhQsX8ueff/Lnn3/i5+dHr169CA8PB+DKlStotVq8vLxKnOfl5XVTP5dic+bMwcnJybD5+flVdzWEEEIIYUTlviVUXoGBgQQGBhp+7tq1K+fOneOLL77g559/rtA1Z86cydSpUw0/Z2ZmSmgRQggharFqDyyl6dy5M7t37wbA3d0dMzMzkpKSShyTlJSEt7d3qedbWVlhZWVV7eUUQgghhGkwyjwsERER+Pj4AGBpaUmHDh3YunWr4XWdTsfWrVsJDg42RvGEEEIIYWLK3cKSnZ1NdHS04eeYmBgiIiJwdXWlXr16zJw5k/j4eJYuXQrA/PnzCQgIoEWLFuTn5/PDDz+wbds2Nm3aZLjG1KlTGT16NB07dqRz587Mnz+fnJwcw6ghIYQQQtzbyh1YDh8+TO/evQ0/F/clGT16NIsXLyYhIYHY2FjD64WFhUybNo34+HhsbW1p3bo1W7ZsKXGNYcOGkZKSwqxZs0hMTKRt27b8888/N3XEFUIIIcS9SaUoimLsQlRWZmYmTk5OZGRk4OjoaOziCCGEEKIMyvP9LWsJCSGEEMLkSWARQgghhMmTwCKEEEIIkyeBRQghhBAmTwKLEEIIIUyeBBYhhBBCmDwJLEIIIYQweRJYhBBCCGHyJLAIIYQQwuRJYBFCCCGEyZPAIoQQQgiTJ4FFCCGEECZPAosQQgghTJ4EFiGEEEKYPAksQgghhDB5EliEEEIIYfIksAghhBDC5ElgEUIIIYTJk8AihBBCCJMngUUIIYQQJk8CixF9vimK8UsOcS4l29hFEUIIIUyaBBYjWRsRz5fbotlyKpmHv9zNb4diURTF2MUyqtxCDf+eTOSLzWc4fCGtVv8+Pv03ijbvbuJoXLqxiyKEEHcFlVILvhUyMzNxcnIiIyMDR0fHKr22TqdQqNVhbWFWZddMyMij3xc7yczX4OtkzeWMfAAeau3DR0Na4WRjUWXvZeoSM/LZejqJLZFJ7DmXSqFGZ3itjZ8z47sHMKClN+ZmtSdbn03Kot/8negU6ODvwh/PB6NSqYxdLCGEqHHl+f42r6Ey3ZXyi7RM//0oWfkafhzdsUq+NHU6hem/HyUzX0Obuk6sfD6Yn3Zf4LNNUaw/lkBEbDpfjmhLB3/XKqiB6VEUhZOXM9lyKomtp5I5Hp9R4nU/VxuaeTsSeiaFo3HpvLj8CHWcbRjTtT7DOvvhaH33h7mPNpxCd+2/CWEXr7I5Mom+LbyNWyghhDBx0sJyG6c
TMxn89R7yi3Q83cWf9wa1qPT/hH/aHcN76yKxtlCz/qUeNPSwByAiLp2Xlh8hNi0XM7WKl/s0ZlLvRpip7/7/eRdotOw7l2oIKQnXWpQAVCpo5+dMn2ZePNDci8ae9qhUKq5kF/DL/ov8vO8iqTmFANhZmjGsUz3GdquPn6utsapTKbvPXuGpHw9grlYxoJUPfx+9TGNPeza+3KNWtSIJIURZlOf7WwLLHfxzIpGJv4ahKDDr4eY80z2gwtc6m5TFQ/+3m0KNjvcHteDp4PolXs/KL2LW2pOsPhIPQOcAV+YPa4uvs01lqmBUSZn5PLZwL3FpeYZ9NhZm9GjsTkgzL3o39cTDweqW5+cXaVkbEc8Pu2I4m6zvnKxWQf+W3ozr3oAO/i7VXoeqotUpPPx/uzmVkMmYrvV55YEm3PfJdtJzi5j3aGue6ORn7CIKIUSNksBSxf638xwfbTiNSgXfP92RkOZe5b5GoUbHkG/2cPJyJvc18WDx2E63bK1ZfeQSb60+QU6hFicbC+Y+2or+LX0qW40aV6jRMfx/+wiPTcfNzpJ+Lb0JaeZJ14bu5e4TpCgKO89e4Ydd59l19ophf1s/Z14OaUzvQM+qLn6V+/1wHDP+OIaDtTk7ZvTG1c6SH3ad54P1p/B2tCZ0Rq8q7SslhBCmrjzf39IGXQYTejRgROd6KAq8tOIIJ/7T76IsFmw9w8nLmTjbWvDJY61ve2tpSLu6bHi5B23qOpGRV8Tzv4Qzc9Vx8gq1lalGjXtv3UnCY9NxsDbnz4ld+WhIK+5v6lWhL2WVSsV9TTz4eVwQ/07pybCOfliaqYmIS+eZxYeIMPHRNnmFWj7dFAXAi/c3wtXOEoCnuvhTx9mGxMx8Fu+9YMQSCiGEaZPAUgYqlYr3BrWgR2N3cgu1jFtyiMQb+mHcyeELaXwbeg6AOUNa4elofcdz/N3s+P35rjx/X0NUKlh+MJaBX+3m+KXyhyVjWHk4jl/2x6JSwYLhbanvbldl1w70dmDuY63Z8/r9hDTzRFHgw/WRJj0M+vtd50nKLKCuiw2jbrgVaG1hxisPNAHgm+3RZOQWGamEQghh2iSwlJGFmZqvR7ansac9SZkFjFtyiJwCzR3Pyy7QMHXlUXQKDG1fhwGtyn5rx9JczesDmvLLuCA8HayITs7mka93M/33oyRllj0w3UmRVseWyKQqm8Du2KV03lpzAoApfZpwf9Py30IrCw8HK94f3BJrCzWHLlzl35OJ1fI+lZWclc/CHfrA+lr/pje1MA1pV4dALwcy8zV8syPaGEUUQgiTJ4GlHBytLfhpTCfc7S05eTmTl1ccQau7/f/qP1gXSWxaLnWcbXjnkRYVet9ujdzZ+HIPBrf1RVHgj7BL9PoklAVbzpJbeOfQdCt5hVqW7L1Ar09CGb/0MA99uYtNlfzST80u4PmfwyjU6OjT1JMX729UqevdiY+TDRN6NADg442nS8zjYiq+2HyW3EItbf2cebj1zYHVTK3itQGBACzac4HL6Xk3HSOEEPc6CSzl5Odqy/9GdcTKXM2WU8l8sD7ylsdujkxixaE4VCr47Ik2lZpDxM3eivnD27H6ha508Hchr0jLF1vOcP+nO/gz7BK6OwSnG2XkFfH19mi6z93G7L9OEp+eh6WZmvwiHc/9EsaPu2MqdHtFo9Xx0oojXM7IJ8Ddjs+HtUVdA8Oyn7uvIe72VlxIzeXn/Rer/f3KIyoxi98OxQLw1kPNbtl3qXegJ50DXCnU6Ji/5UxNFlEIIe4KElgqoH09Fz5/oi2g/x/x0n0XbjrmSnYBr/95DIBnezSgSwO3KnnvdvX0M6N+9WQ76rroO2tO+/0og7/Zw8GYtNuem5yZz5yNp+j28TY++TeK1JxC6rrY8P7gloTPeoAng/Qdi99fF8k7f51Eoy1fa8Unm6LYE52KraUZC5/qUGMz9tpbmTOtr74fyJdbz5KeW1gj71sWczbqJ4kb0NKbjvVvPRmgSqX
i9QFNAX0L2pmkrJoqohBC3BUksFTQQ619mNFP34z/zl8n2R6VbHhNURRe//MYqTmFNPV2YOq1L9OqolKpeLi1L1um3sdr/Ztib2XOsUsZPPHdPib+EsbF1JwSx19MzeGN1cfpPm873+04T3aBhkAvB+YPa0vo9F483cUfeytzPhzckjce1H9pLtl3ked+DitTPx2A9ccS+G7HeQDmPdaaQG+HKq3znTzR0Y9ALwcy8or4v22m0Q9k19kUQqNSsDBT8Vr/pnc8vn09F/q38EanwLx/omqghEIIcfeQwFIJL/RqyOMd6qJTYPKv4ZxKyATgt0NxbDmVjKWZmi+GtcXKvHrm1rC2MGNir4aEzujFyKB6qFWw8UQiD3y+k482nCLsYhovLj9C709DWXYglkKNjg7+Lvw4uqO+T0y7OiVmV1WpVDzbsyHfjGyPlbmaraeTeeK7fXfs4HsmKYsZfxwF4NmeDXi4tW+11Pd2zNQq3nioGQBL913gwpWcO5xRvbQ6hQ/XnwLg6S71yzxKanq/QNQq2HIqicMXbt9iJoQQ9xKZOK6SCjU6Rv90kH3nU/F1smbBiHaM/ukguYVa3niwKc/2bFhjZYlKzOKD9ZElJlYr1ivQgxd6NaJzQNnWKAqPvcqEJYdJzSnEx8man8Z0opnPzb/bzPwiBn21h5grOXRt6MbSZzobdYr5UT8dZOeZFAa09ObbpzoYrRwrD8fx6h/HcLw2SZzLtXlXymLmqmMsPxhHR38Xfr+HFkbML9Ly876LWJipaOBhT4C7Hb7ONrVieQohROlkptsalpFbxJBv93A+JQeVChQFggJcWTahS43/Y6soCqFnUvhw/SnOp2TzUGtfJt7XkOa+5f+9xKbmMnbxQc6l5GBvZc7XI9tzXxMPw+s6ncKzPx9my6lk6jjb8NfkbrjZ33qa/ZoQlZjFgAX6lZB/fz6YTrfpN1Jdcgs19PoklOSsAt58sBkTejYo1/mJGfn0+nQ7+UU6vh/VkQcqMLPy3ag4qN3I0lxNfTdbAtztCHC3p4GHHQ3c7Qhwt8PVzvKeCXNlpdHqWH88ge6N3I3+d1GIspDAYgQXU3MY8s1e0nIKcbAyZ+OUHtR1Md4CfTqdQnahptKrG2fkFvHcL4fZfz4NM7V+Ar2RQf6AvoPr55vPYGmu5o/ng2ld17kKSl55xV98bfycWT2xa42MVLrRgi1n+WLLGfxcbdgy9b4K3RKc989pvgk9d88sjLg2Ip6XV0SgUulHTMWl5XIxNZfC23T8drQ2p6mPI4+1r8vANr7YWMqyBt/vPM+HG07RxMue1S90w87K3NhFEuK2JLAYyZHYq3y++QzPdA+4K9a2KatCjY7XVx1jVbh+UcbnejYgqIEr45YcRlH0nWyf6Gg6C/clZ+XT+5NQcgq1LBjelkFt69Tce2fm0+vTUHILtXz1ZLsK9+fJyCu6ZxZGPJ+SzcD/201OoZaX7m/E1L76zuxancLl9DzOpWQTcyXHsJ1PyeFyRh43/svlaG3O4x39GBlUjwbXVkC/1+h0Cr0+DSU2LRfQDwz4akQ7aYUSJk0Ci6hyiqLw5dZovrg2R0jxra+RQfX4cEgrI5fuZl9tO8unm85Qx9mGrdPuq7FFBV//8xgrDsXRrp4zqyZ2rdSXRfH/ln2crNk+vWwLI15Oz2NP9JVrfYrcCW7oVmW3JaOTs/jtUBy7zl5hQo8GPNqhbqWvmV+kZeg3e4lMyKRzgCvLxgeVqTUpv0jLhdQcQqNS+GX/RS5dvT7ZXo/G7jzVxZ8+TT1rfcvUjUKjkhmz6BB2lmYUanUUaZUa70dXUZfT85iyIoIAdzs+GNISi3voc7vXVevihzt37mTgwIH4+vqiUqlYs2bNbY9ftWoVDzzwAB4eHjg6OhIcHMy///5b4ph33nkHlUpVYmva9M7DQEXNUalUvBzSmC+GtcHCTIWiQLt6zsweWLHZe6vbuO4N8HGyJj49j0V7LtTIe55OzGTlYX0
fjNtNEldWTwf74+tkTUJGPktusTDi1ZxCNhxP4M3Vx+n9aShdP97GjD+O8U3oOZ768QDdPt7GnI2niEqs2LwuuYUafj8cx2Pf7iXk8518vyuG04lZTP/jKH+GXapE7fQ+XH+KyIRMXO0s+XJ4uzIHDGsLM5p6O/L8fQ3ZMaM3i8Z04v6mnqhUsOvsFZ77OYwe87bz5dazJFfhMha3cjox86bpBGrarwf0ExQ+3tGPWQ83B/SzP++JvrkTvilJySrgqR8OcPBCGr8djmPG70fLNRGmuHeU+wZnTk4Obdq04ZlnnmHo0KF3PH7nzp088MADfPTRRzg7O7No0SIGDhzIgQMHaNeuneG4Fi1asGXLlusFM5d7r6ZoSLu61HO1Y3NkEuO6B2Bpbpr/E7KxNGNGv0CmrjzKN9ujebxjXdyrsROioih8tOE0OgUebOVNB//Kd/a1tjBjat9Apv9+lK+3RzO8Uz3MzVQcvJDG3ugr7IlO5VRiZolbI2oVtKrrTICbLdujUkjMzOe7Hef5bsd5Wvg6MqRdHR5p64unw60X4FQUhePxGaw4FMdfEZfJvjYXj5laRe9AT+yszFgbcZkZfxzF0lzNwDYVu+214XiCYWbiz59og7fTnRcFLY2ZWkXvpp70bqrv+/LrgVhWHo4jISOfzzef4cutZ+nX0punu/gTFOBapbdIMnKLmLPxFCsOxWGuVjGxV0Mm39+o2qYyuJXL6XlsPZUEwFNd6tHQw56IuAz+DL/E5GXh/P1id6P2qbuV9NxCnv7xAOev5ODpYEVaTiFrIi7jbGvJ7IHN5XaWKKFSt4RUKhWrV69m8ODB5TqvRYsWDBs2jFmzZgH6FpY1a9YQERFRoXLILSFRGp1OYdDXezgen8FTXerxweCqv3Wl1Sn8ezKR73ae52hcOhZmKrZMvQ9/t6pZnVqrU3hwwS6ikrL0Mxtn5KP5z/8+m3jZ07WhO90auRPUwNXQ0bpAo2X76WRWhcezPSqZIq3+PDO1ih6N3RnSrg59m3sbOqtm5BaxJiKeFYfiDHMKAfi72fJERz8e61AXL0drdDqFN9ccZ/nBOMzUKr5+sh39W5Z9UU/Qj0B76MtdZBVomNirYZkm1iuP/CItG08k8Mv+WMIuXjXsb1XHien9AunZ2L1SX4aKorDheCKz/zrJleyCEq818rRn7qOtqiS0ltXnm6L4cls0XRq4suLZYED/O3hs4V5OxGfSqo4Tvz8fXGO3RssiK7+Ip344wNFLGXg6WLHyuWCOXkpnym8RKAq83KexYSXzysou0KAoCg6VHIQgql55vr9rvBlDp9ORlZWFq2vJv8xnz57F19cXa2trgoODmTNnDvXq1Sv1GgUFBRQUXP9HIjMzs9TjxL1NrVbx5kPNGP6//Sw/GMeYrvVp5Fk1M/DmFWr5PSyOH3bFGDo5WpqrefPBZlUWVuD6wojPLD5s6KdR18WGbg3d6drIjeCGbrdsLbEyN6N/Sx/6t/Thak4h645dZtWReI7EphMapZ+F197KnP4tvdHqFDYcT6Dg2uKRluZqBrT0ZlgnP7oEuJUYaaVWq/hwcCsKNQp/hl/ixeVH+HakmpAyDr8u0GiZvDycrAINHfxdmFpFX0o3srYwY0i7ugxpV5eTlzP4ZX8sa47Eczw+g9E/HaRzgCuv9gu87XIJt3I5PY+315xg62n97NYNPeyYM7Q1qdkFvL32JNHJ2Ty2cB+jg+szo19gtY/UKdLqWHFIfyuyeAQf6H8HC5/qwMD/283x+AzeWnOCTx5rbRKtFnmFWsYtPszRSxm42Frw6/gg6rvbUd/djoy8ImatPcmCrWdxtrVgbLeASr3XumOXef3P41iYqfh1fJcKTfEgTEONt7DMmzePjz/+mNOnT+PpqR9Js3HjRrKzswkMDCQhIYF3332X+Ph4Tpw4gYPDzV8w77zzDu++++5N+6WFRZRmwtLDbI5M4v6mnvw0plOlrnUlu4Cl+y7y874LXM0tAsD
Z1oJRXfwZ1bV+tdx2UhSFv45eJq9QS9eG7tRzq1zT/vmUbNYciWfVkfgSnVUBmno7MLyTH4Pb1cHZ9vaT3Wl1Cq/8FsFfRy9jaabmf6M60KsMo+Pe/fski/ZcwNnWgg0v9cDX2aZS9Smr1OwCvg09x9L9Fw2revcO9GBa30Ba1nG64/lancLP+y7wyb9R5BRqsTBTMbFXIyb1bmi4BZSeW8gH60/xx7X+PXWcbZgztBU9b5i/qKptPJ7AxF/Dcbe3Yu/r9990m3b32SuM+ukAOgXeH9ySp7v43+JKNaNAo2X8ksPsOnsFB2tzlk/octPvv3jKBNDfLhzavvwdvAs0Wj5cf4ql+64viOpqZ8mKZ7vQxKtmlw4Rt1Zjo4TKG1iWLVvGhAkTWLt2LSEhIbc8Lj09HX9/fz7//HPGjRt30+ultbD4+flJYBGlOp+STd8vdqLRKfw6PohujdwrdI0fdsfwZ9glQytEPVdbxvcI4LEOdbG1vPv6XOl0CocvXmXdscuogKHt69K6rlO5/geu0ep4cfkRNp5IxMpczU9jOt329/vvyUSe+zkMgB9Hd6RPs5qfFC8hI48vt0az8nAc2mu31x5q7cO0B5rcckj06cRMXv/zOBFx6QB08HdhztBWt/zi23kmhZmrjhOfrg+Ej3Woy1sPNbtjCKyIkT/sZ090KpN6N2RGv9JvrX234xxzNp7GwkzFime71OjtqhsVaXVM+jWcTZFJ2Fqa8fO4zqWWRVEU3l93ip/2xGCmVvHdUx3K3IIH+luOk5aFczw+A4Dn7mvAvnOpHLuUgbu9Fb8914WG9+jwd1NjkoFlxYoVPPPMM/z+++889NBDdzy+U6dOhISEMGfOnDseK31YxJ2889dJFu+9QDMfR9a92L3MQ33DLqbx3Y7zbD6VZOjc2qauE8/2bEj/lt4ybTz6eXpe+DWMLaeSsbEwY/HYTgSVsjr5pau5PLhgF5n5Gib0CODNh5obobTXXbiSwxdbzvDX0csoiv7222Pt6/JSSGPqXGv1yS/S8tW2aBbuOIdGp2BvZc5r/QMZGeR/xwkJcwo0fPJvFEv2XUBRwN3eivcHtWBAq/L197md8ynZ3P/ZDv3oqFd737JjraIoTF52hPXHE/B0sGLdi93xdKxYJ+eK0uoUpq6MYG3EZSzN1Sy6Q7jV6RRm/HGMP8MvYWmuZsnYzgQ3vPOq9/+cSGTGH0fJytfgYmvB58Pa0jvQk/TcQkZ8f4BTCZl4Oer7zFTl7VtRMSYXWJYvX84zzzzDihUrGDRo0B2vm52dTb169XjnnXd46aWX7ni8BBZxJ2k5hdz3yXay8jW0rOOIuVqNRqdDo1Uo1Oofi67NXaHR6SjS6CjSKYZbBwB9mnrybM8GdK7ikSa1QYFGy7NLw9hxJgU7SzOWjguig7+L4fUirY4nvtvHkdh02vg58/tzwSYzwuxUQiafbYpiyyl9nxRLMzUju9Sja0N3PtpwiphrC2k+0NyL9wa1wMepfLewwi6m8eofxziXor9O/xbevDeoRZUEhg/WRfLD7pgy3e7MKdAw+Os9nE3OpqO/C8smdKmxz0BRFN5Yre+oba5W8d3THcrUuqbR6pj4azibI5Owt9LfPmpVt/Tbd4UaHR9vPM1Pe2IAfSvY/41oV+KWY2p2ASO+38+ZpGzqONvw23NdTHL01L2kWgNLdnY20dHRALRr147PP/+c3r174+rqSr169Zg5cybx8fEsXboU0N8GGj16NAsWLCgxDNrGxgYnJ/0fvOnTpzNw4ED8/f25fPkys2fPJiIigsjISDw87nzvVwKLKIsfdp3ng2srKJeVpZmaIe3qMKFnQJV12K2t8ov0fRN2R1/BwcqcX8YH0cbPGYCPNpzifzvP42htzvqXeuDnanpfEmEXr/LJv6fZf77kKtmeDla8N6hFuUdC3ahAo+XrbdF8E6pvqXG0Nud/ozrSpZSWqLLKL9IS9NFWMvKK+GlMR+5veucAcD4lm0F
f7SGrQMPoYH/eHdSywu9fVjfe3lGr4MsR5ZsBOr9Iy9hFh9h3PhVXO0tWPhdMI8+St3MuXc1l8rIjhlt2z/ZswIx+gaVOQJeclc/w7/Zz/koO9Vxt+e25LuUOoaLqVGtgCQ0NpXfv3jftHz16NIsXL2bMmDFcuHCB0NBQAHr16sWOHTtueTzA8OHD2blzJ6mpqXh4eNC9e3c+/PBDGjYs2wyNElhEWeh0CjvOppBfqMXcTI25mQpLMzXmahUW5mos1Pp9FmZqLMxUmJupcbaxkPVYyiGvUMvoRQc5GJOGo7U5y5/tQlJmPs8sPgzAd093oF8LbyOX8tYURWFPdCqf/HuaY/EZjOhcj9f6N8XJpmqGw55KyOTVP45xPF4/lPffKT3LtZL3jf4Iu8T0349Sx9mGna/2LvPtyS2RSYxfqv88Pnu8TZXMWHw7xUOuAT55rDWPV2AZj6z8Ip78/gDH4zPwdbLm94ldDbfttp5KYurKo2TkFeFkY8Fnj7e5Y3+XxIx8hv1vHxdTc2ngbseKZ7vU+C0yoSdT8wshjCa7QMOoHw8QHpuOi60FCpCeW8SYrvV55xHTnBn5vxRFIadQi301hNW8Qi0P/d8uzqfkVGq9n8Ff7yEiLp0Z/QKZ1LtRuc4tnlDPylzNnxO7lhilU6jRkVeoJbdIQ06BVv+8UENukf65pZkaWysz7CzNsbMyw9bSHDtLc2ytzG5q0fg29Bxz/zkNwHuDWjAquH6561ksNbuAJ77bx7mUHBp42LF8Qhd+2h3DdzvPA9DGz5mvRrQrc+tdfHoeTyzcR3x6Ho097VnxbBdZ4doIJLAIIYwqM7+Ip69NCgb6Cdv+mBhc4zPAmqpjl9IZ+s1eNDqlQgt0nojP4OH/242FmYq9r/fBw6F8X7Q6ncK4JYfYfm0uHnsrc30oKdTeNDFheViaq7Gz1IcYG0szopOzAXitf1Mm9qr8mkaX0/N47Nu9XM7Ix8pcbRixN7ZbfWYOaFbuPjmxqbk88d0+EjPzaertwPIJXSrc4iUqRgKLEMLoMnKLmLD0MPHpeSybECQjMv6jeK4RB2tz/p3Ss1zz0cxcdZzlB2N5uLUPXz3ZvkLvn5FXxOCv9xg6Ff+XuVqF7bXwYWtpho2lGTYWZhTpFHIL9OEmp1BDboGWQq2u1GsATO7diOn9AitUxtKcS8nmiYX7SM0pxMHKnHmPta7UyKvzKdkM+99+UrIKaFnHkV/Hd7ntLUBFUbiYmkvYxascvniV8ItXKdLq6N7Ynd5NPQlu4GZSMwqbOgksQgiTodUpMvy7FBqtjscW7iMiLp2uDd34ZVzQHYdKg74/R9BHW8kt1LLi2S6V6ribkVdEVGIWNhZm2FqZ6QOKhb51pDytFcW3kXIKNeQW6m8l5RRqcLaxrJaZZaOTs/gzPJ7hnfyqJAifTcpi+P/2k5pTSFs/Z34e19kwjX9+kZYT8RklAkpqTuEtr2VtoaZbQ3d6NfXk/qaehr42onQSWIQQ4i4QcyWHBxfsIq9Iy9sPN2dc9ztPQ//zvgu8vfYkjTzt2fxKTxliX0VOJWQy4vv9pOcW0cHfhfb1nAm7eJUT8Zk3tSBZmqlpXddJf5y/C2qViu1RyWw/nUxCRsnVwQO9HOh9Lby0r+dc5hXJ7xUSWIQQ4i7xy/6LvLXmBJbmata/2J3Gt5k2XlEU+s/XL4Y5e2DzSq+zI0o6EZ/BiO/3k5WvKbHf3d6KDv7OdPR3pb2/Cy3rOJbaH0tRFE4nZrHttD68hMde5cYuQU42FvRs4kFLX0fqudri52pLPTdbw4Kl9yIJLEIIcZdQFIWxiw8RGpVCC19HVr/Q7Za3Yw5dSOPxhfuwsTBj/xt9qmy4tbjuaFw6n26Kop6rLR38Xejo74qfq02FWrKu5hSy82wK204ns+NMCunX1h/7Lxdbi+sB5obNz9UWX2ebWn1LVQKLEELcRZIz8+k3fydXc4tuuybQyyuOsDbiMsM6+jH3sdY1XEp
RGVqdwpHYq+w6e4ULqTnEpuUSl5bLlexb94cBsLcy56FWPjzRqS7t67nUuluAEliqSlEenFoHySch5J2qu64QQvxH8arLahX8/nzwTYsCpmYXEDxnG4VaHX9N7kbrus7GKaioUtkFGuLScg0BJvaG7VJaXon+Mw087Hiiox9D29WpNRPdSWCpKmkx8GVbQAVTjoNz+WdoFEKIspq6MoJV4fHUc7Vl48s9SsyyvHDHOT7eeJrWdZ34a3J3I5ZS1JTiFdVXHo5j/bEE8oq0gH6hzt6BHjze0Y/7m3qWugTB7WTmF3HhSg7JmQW42Fni6WCFu70VNpY1PxxbAktVWvwwXNgFvd+E+16t2msLIcQNMvOLGDB/F/HpeYzo7MecofrbPjqdQq9PQ4lNy2Xeo615opP85+lek12gYf2xy6w8fImwi1cN+93sLBnSrg5PdPKjyQ0dtgs0WuLScjmXkkPMlRxirj2ev5LDleyCUt/DwcocdwcrPOyt8HC4vrnbW+qf21vTwtexTMPvy0oCS1U6ugJWPwfO/vBSBKhlSJoQovrsP5/KiO/3oyjww6iOhDT3IjQqmTGLDuFgbc6BN/pgaynrW93LopOz+SPsEn+GXyIl63r4aFPXCWdbS2Ku5HDpai63m7TYw8EKL0cr0nOLSMkqMMwafDtqFZz98MEq7QRcnu9v+VN/J80egfXTIf0iXNwDAT2MXSIhRC3WpYEb47sH8P2uGF5fdYx/6/Xk1wOxADzavq6EFUEjT3teH9CU6X2bsONMCisPx7H1VLJhKYxi9lbmBLjbGbYGHnY0cLenvrutYWI80I9UyyrQcCWrgJSsAlKyrz1mFXAl+/o+rQ6jjliSP/l3YmkLLYdC+BI48osEFiFEtZvWN5CdZ64QlZTFpGXhHIxJA+CpLvWMXDJhSszN1PRp5kWfZl5cyS7g35OJqFUqQzjxsLcq06gilUqFo7UFjtYWNPCwr4GSV4zc3yiLdk/rHyPXQn6mccsihKj1rC3M+GJYWyzMVOw/n4ZOgS4NXGnkeetJ5cS9zd3eipFB/ozoXI8uDdzwdLCudUOgJbCURd2O4N4ENHlwcpWxSyOEuAc093VkWt/riwaODPI3YmmEMD4JLGWhUkG7p/TPj/xi3LIIIe4ZE3o0YEi7OoQ086RfC29jF0cIo5JRQmWVlQSfNwNFC5MOgkfVLZcuhBBC3IvK8/0tLSxl5eAFjfvqn0srixBCCFGjJLCUR/FtoaMrQFv6IlZCCCGEqHoSWMqjST+w84CcZIjeYuzSCCGEEPcMCSzlYWYBrYfpn8ttISGEEKLGSGApr7Yj9Y9n/oHsFOOWRQghhLhHSGApL6/mUKcD6DRw7Ddjl0YIIYS4J0hgqYjiVpYjv8DdPypcCCGEMHkSWCqi5aNgbg0pp+ByuLFLI4QQQtR6ElgqwsZZv4ozSOdbIYQQogZIYKmodtduCx3/E4ryjFsWIYQQopaTwFJR9XuCUz0oyIBT64xdGiGEEKJWk8BSUWr19VaWIz8btyxCCCFELSeBpTLajNA/xuyEqxeNWxYhhBCiFpPAUhku/hBwH6DA0eXGLo0QQghRa0lgqax2T+sfj/wKOp1xyyKEEELUUhJYKqvZw2DlBBmxcGGXsUsjhBBC1EoSWCrLwgZaPap/LnOyCCGEENVCAktVaPuU/vHUX5CfYdyyCCGEELWQBJaqUKc9eDQDTT6c+NPYpRFCCCFqHQksVUGlgnbXWlnktpAQQghR5SSwVJXWw0BtDvFhkHzK2KURQgghahUJLFXF3gOa9Nc/X/IIrJ8GMbtApzVuuYQQQohawNzYBahV7nsN4g5ATjIc+kG/2XlC80eg+WDw7wpqM2OXUgghhLjrlLuFZefOnQwcOBBfX19UKhVr1qy54zmhoaG0b98eKysrGjVqxOLFi2865uuvv6Z+/fpYW1sTFBTEwYMHy1s04/NpDa9Ewsg/9SOHrJ2vh5clD8NnTWHdVP1U/tLyIoQ
QQpRZuQNLTk4Obdq04euvvy7T8TExMTz00EP07t2biIgIpkyZwvjx4/n3338Nx/z2229MnTqV2bNnEx4eTps2bejXrx/JycnlLZ7xmVtC4xAY/DVMP6sPL+1uCC+Hf4QlA+GzQH14ObNJvw6RzJIrhBBC3JJKURSlwierVKxevZrBgwff8pjXXnuN9evXc+LECcO+4cOHk56ezj///ANAUFAQnTp14quvvgJAp9Ph5+fHiy++yOuvv37TNQsKCigoKDD8nJmZiZ+fHxkZGTg6Ola0OtVLWwTnd0Dkaji1DvLTS75ubg2uDcCtEbg31j+6NQa3hmDrapQiCyGEENUpMzMTJyenMn1/V3sfln379hESElJiX79+/ZgyZQoAhYWFhIWFMXPmTMPrarWakJAQ9u3bV+o158yZw7vvvlttZa4WZhb6lpfGIfDwfIjZASfXQNxBSDuvn8MlOVK//Zet27Xw0gjsPcHa6T+bs/7RxhmsHPWtPNUp7yqkx+nLZe+pr5sQQghRjao9sCQmJuLl5VVin5eXF5mZmeTl5XH16lW0Wm2px5w+fbrUa86cOZOpU6cafi5uYblrmFlAoxD9BqDV6NciuhINqdGQelb/eCUasi5Dbqp+i9tftutb2OoDjK2bvoXGrTG4N9G33Lg3BiuHsl1Hp4P0C5B4AhKPQ9K1x4y4Gw5SXQsuXuDgpX+09wIHb32YsffWP7dy0A/7Vpvr61/8XKUqz29OCCHEPequHCVkZWWFlZWVsYtRdczM9beDXBsAfUu+VpANaefgyll9S0xumn76f8OWfv15Qab+nKJc/ZaVoA8Z/+Xgc+3W0w0hxq2RPhSVCCcnoDCr9DLbuOrfU9FC7hX9lnyy/HUvDi5qC/0IKjML/XMrB7B21LcYWTvqA1jxc6vilqVrr9u46EOTrZv+dymEEKLWqfZ/3b29vUlKSiqxLykpCUdHR2xsbDAzM8PMzKzUY7y9vau7eKbPyh582ui3O9Fp9aGlOMBkJelba66c0bfWXDmj7/iblaDfyrK6tJkVeDYF71bg1eraYwv97SedTh9yspMgO1H/ftk3bFk37C/KuUWZNfqN/JL7b5GT7sjaCWzdwc79eogxPL+238YVbF30j1aOoJbpiIQQwtRVe2AJDg5mw4YNJfZt3ryZ4OBgACwtLenQoQNbt241dN7V6XRs3bqVyZMnV3fxahe1mb61wcZF/7MP3NRik5d+7XbTtSCTevZ6642Voz6QeLcE79bg1VLf+nKrPipqtX7CPHsPoOXty6Yo+kCl04CuSP+o1ZT+s7YACrIgP/NaAMu83oJ042Px63lX9S1PKNfDWtq5sv3OVGb68GXjqu/cbON6rcXm2qO1k35FbnPrsj2aWVbuNldhLmTGQ3osZFzS337LuKTfrJ2gbkeo0wF825X91p4QQtQC5Q4s2dnZREdHG36OiYkhIiICV1dX6tWrx8yZM4mPj2fp0qUAPP/883z11Ve8+uqrPPPMM2zbto2VK1eyfv16wzWmTp3K6NGj6dixI507d2b+/Pnk5OQwduzYKqiiKMHGWf+lV7djyf3Fg8Wqq0+JSqW/XWNmDlhX/fV1Wn0Yy70COVeu9fu59piTev3nnCvXA05RzrVbWtdeT62Ccqgt9K1ilg76QGFlr3+0tL/23PHacwdQqfXhJCNO34k545K+jLdzep3+UaUGj6b68FKng/7z9Ggmt8SEELVWuYc1h4aG0rt375v2jx49msWLFzNmzBguXLhAaGhoiXNeeeUVIiMjqVu3Lm+//TZjxowpcf5XX33FJ598QmJiIm3btuXLL78kKCioTGUqz7AoIQw0BfrgkncV8tKuPb/x8SoUZEBRvn4UV1HerR+p8OwAN7O0Byc/cKoLzn7654519LfX4sPgUhhkXrr5PAtb8GkLdTuAdxt9SNQWgbbw2lakb7367z5Ngb41ybedfuVxR9+qq4sQQtxGeb6/KzUPi6mQwCKMSlH0X/5FeVC
Yo7+dVZitv11VkH3tedb1rfhnbZE+HDjXKxlQrJ3v3NKVlQiXDusDTPxhiD9y6w7S5WXvrQ8uvu2hTjv9o8wFJISoBhJYhLjX6LT6PknxYfogc+WM/raRmeW1zQLMra4/N+y/9nNmvD70pJwCpZRZl13qXwsw7fUtMY6++o7MVo4yNF0IUWESWIQQFVOYAwnH4HI4xIfD5SO378CsNr/WYfnaiCzbG4aYF++3cwc7D/28PHYeMtGgEMLApGa6FULcRSztwD9YvxXLuwqXI66HmMTj+s7LRTn6UV05yfqtrGxcr00q6Klfzfy/zx189C04Ni7SeiOEMJDAIoS4PRsXaNhbv92oKP9a5+TiUVhp1x9v3J+TAtkp+kdFq38tLw1SSp/J2sDc+np4KX4s3hx8wdFH399GRkYJcU+Qv+lCiIqxsAYL37KPKtLp9EEl+1qLTHbxlnQt1Fx7npWgDzqafLgao99uRaXWhxfDiKq61zow+13/WearEaJWkMAihKgZavW1/izuQPPbH6sp0AeXzMvXtxt/Lp6tWafRD/HOvHTrtbasna+PwrJ2uqHTscV/OiBfe66+tt/STj9xokczsLSt6t+GEKKcJLAIIUyPuZV+ZJJL/Vsfo9PpW2oyLt08M3B6nP55fvr1Lel4BQujAtcA8GyuX5ai+NG1gX52aSFEjZBRQkKI2qsg6/rSBumx+lFQ2sJrS0DcOKFe8fMb9udnQPKpW88+bG4NHoHg2QI8m+lDzY3rWFk7yzpVQtyBjBISQgjQ91/xbKbfKio7GZJOQnKkfkuK1HcYLsqFhKP6rTQqs5KLb9q5Xw809p76tbq8W+nXoBJC3JEEFiGEuJ3iYdc3jpLS6fSdgYsDTHKkfvK94nWsCjL1I6LuNORbba6/xVSnw/XZhT2aysgnIUoht4SEEKKqaQquDem+cn3RzRufZ16GhAj96Kj/srAFnzbXZxau0x5cAmROGlEryS0hIYQwJnOr63PG3Iqi6PvWxIfdMLNwhH5NqNh9+s1wPetrE+t56GcLtnPX/2yYQfiGn21dpTOwqJUksAghhDGoVPq5Ypz9oMVg/T6dDlLP6sNLcZBJPK6fkyYjVr/d+cL6Ydlqc/2mUl9/rjbTbyqz6/tsnMEvCPy7Qb0g/dBvIUyQ3BISQghTpim83j8mJ+Vav5gbZg++cctNAyrxT7pKre8I7N/t2ta17Ct1K4r+/a/GQNq1Cf9y0651MvbSbw7XHm3dpZ+OAGTxQ2MXRwghjEOr0c8mXDx0W6fVb4q25M86zfV9GZfg4h64uBfSzt98Tc/m+uDi3xXqddWfczVGf2xxMEmLgasX9J2Ny0RVMsjYe+lvbXk2hwb3gYN3Vf5WhAmTwCKEEKL8Mi/rg8vFvfoQc6f1nkrj4KPvJOwaoB/OnZumX3KheMtJAUV3+2u4B0KDXvrwUr+73KaqxSSwCCGEqLycKyUDTOJxfb8XF/9rMxFfCybFj87+d17GQKfVj6AqDjBZSdfXkIo7eG1emxu+llRq8G2nDzAB9+n721hYV099dTp9p+fCHP3Ef7IkQ7WTwCKEEKLqFebqR0BV5yik3DS4sBvOh0LMDkiNLvm6ufW1TsJd9es9obo25PvasO/i5yUe0Q81z8/Q37bKz7hhu+HngkxKhCULu2sjs67NxWMYleVxw/Nro7SsHKtuZmOdTr+cRG6q/haca0Mwt6yaa5sYCSxCCCFqh4xLcH6HPryc3wHZidX/nir1nW9blXaOjcsNm6v+0da15H5rZyjMLn1+HsPcPan6PkbF1Ob622TeLfXrWHm10M+UbO9V8fl5CnP0S1dYOxl1tmUJLEIIIWofRYGUKH14STh6rfOwAih3eES/EreNs74lxNrp2lb83Fn/WPyauZX+yzwnRb80Q4mRWcnX9qVcfyzMrp76Wl37PrtVZ2ZbN31w8SoOMs31Q9ZzUm4YVXaL55q869exsNUHLFsX/TVtXPVB68bnxY8+bat0jSw
JLEIIIURNKcqHvKvXtrTrz3PT/rMvXb9Z2d+8vpStO9i5lVxA09zq+gSDSSeubSch8QSknSt/K9BNVJRrGLxKDW+nGi2wyEB4IYQQojIsrMHCBxx9qv7aN04wGDjg+v7CXP0orqST14NMcqS+hcUwG7JHKc9v+NnSTt96k5umD1W5aTc8T/3P86uAYtQVyCWwCCGEEHcbS9vra01VRvHtMQKqpFjVyXhRSQghhBCijCSwCCGEEMLkSWARQgghhMmTwCKEEEIIkyeBRQghhBAmTwKLEEIIIUyeBBYhhBBCmDwJLEIIIYQweRJYhBBCCGHyJLAIIYQQwuRJYBFCCCGEyZPAIoQQQgiTJ4FFCCGEECZPAosQQgghTJ65sQtQFRRFASAzM9PIJRFCCCFEWRV/bxd/j99OrQgsWVlZAPj5+Rm5JEIIIYQor6ysLJycnG57jEopS6wxcTqdjsuXL+Pg4IBKpbrtsZmZmfj5+REXF4ejo2MNlbDmST1rF6ln7XIv1PNeqCNIPStLURSysrLw9fVFrb59L5Va0cKiVqupW7duuc5xdHSs1X+4ikk9axepZ+1yL9TzXqgjSD0r404tK8Wk060QQgghTJ4EFiGEEEKYvHsusFhZWTF79mysrKyMXZRqJfWsXaSetcu9UM97oY4g9axJtaLTrRBCCCFqt3uuhUUIIYQQdx8JLEIIIYQweRJYhBBCCGHyJLAIIYQQwuTdc4Hl66+/pn79+lhbWxMUFMTBgweNXaRKeeedd1CpVCW2pk2bGl7Pz89n0qRJuLm5YW9vz6OPPkpSUpIRS3xnO3fuZODAgfj6+qJSqVizZk2J1xVFYdasWfj4+GBjY0NISAhnz54tcUxaWhojR47E0dERZ2dnxo0bR3Z2dg3W4s7uVM8xY8bc9Nn279+/xDF3Qz3nzJlDp06dcHBwwNPTk8GDBxMVFVXimLL8OY2NjeWhhx7C1tYWT09PZsyYgUajqcmq3FJZ6tirV6+bPs/nn3++xDGmXEeAb7/9ltatWxsmDwsODmbjxo2G1+/2z7HYnepZGz7L//r4449RqVRMmTLFsM/kPk/lHrJixQrF0tJS+emnn5STJ08qEyZMUJydnZWkpCRjF63CZs+erbRo0UJJSEgwbCkpKYbXn3/+ecXPz0/ZunWrcvjwYaVLly5K165djVjiO9uwYYPy5ptvKqtWrVIAZfXq1SVe//jjjxUnJydlzZo1ytGjR5VHHnlECQgIUPLy8gzH9O/fX2nTpo2yf/9+ZdeuXUqjRo2UESNG1HBNbu9O9Rw9erTSv3//Ep9tWlpaiWPuhnr269dPWbRokXLixAklIiJCefDBB5V69eop2dnZhmPu9OdUo9EoLVu2VEJCQpQjR44oGzZsUNzd3ZWZM2cao0o3KUsd77vvPmXChAklPs+MjAzD66ZeR0VRlL/++ktZv369cubMGSUqKkp54403FAsLC+XEiROKotz9n2OxO9WzNnyWNzp48KBSv359pXXr1srLL79s2G9qn+c9FVg6d+6sTJo0yfCzVqtVfH19lTlz5hixVJUze/ZspU2bNqW+lp6erlhYWCi///67Yd+pU6cUQNm3b18NlbBy/vtFrtPpFG9vb+WTTz4x7EtPT1esrKyU5cuXK4qiKJGRkQqgHDp0yHDMxo0bFZVKpcTHx9dY2cvjVoFl0KBBtzznbqynoihKcnKyAig7duxQFKVsf043bNigqNVqJTEx0XDMt99+qzg6OioFBQU1W4Ey+G8dFUX/JXfjl8F/3W11LObi4qL88MMPtfJzvFFxPRWldn2WWVlZSuPGjZXNmzeXqJcpfp73zC2hwsJCwsLCCAkJMexTq9WEhISwb98+I5as8s6ePYuvry8NGjRg5MiRxMbGAhAWFkZRUVGJOjdt2pR69erdtXWOiYkhMTGxRJ2cnJwICgoy1Gnfvn04OzvTsWNHwzEhISGo1WoOHDhQ42WujNDQUDw9PQkMDGTixImkpqY
aXrtb65mRkQGAq6srULY/p/v27aNVq1Z4eXkZjunXrx+ZmZmcPHmyBktfNv+tY7Fff/0Vd3d3WrZsycyZM8nNzTW8drfVUavVsmLFCnJycggODq6VnyPcXM9iteWznDRpEg899FCJzw1M8+9lrVj8sCyuXLmCVqst8YsF8PLy4vTp00YqVeUFBQWxePFiAgMDSUhI4N1336VHjx6cOHGCxMRELC0tcXZ2LnGOl5cXiYmJxilwJRWXu7TPsfi1xMREPD09S7xubm6Oq6vrXVXv/v37M3ToUAICAjh37hxvvPEGAwYMYN++fZiZmd2V9dTpdEyZMoVu3brRsmVLgDL9OU1MTCz1My9+zZSUVkeAJ598En9/f3x9fTl27BivvfYaUVFRrFq1Crh76nj8+HGCg4PJz8/H3t6e1atX07x5cyIiImrV53irekLt+SxXrFhBeHg4hw4duuk1U/x7ec8EltpqwIABhuetW7cmKCgIf39/Vq5ciY2NjRFLJipr+PDhhuetWrWidevWNGzYkNDQUPr06WPEklXcpEmTOHHiBLt37zZ2UarNrer47LPPGp63atUKHx8f+vTpw7lz52jYsGFNF7PCAgMDiYiIICMjgz/++IPRo0ezY8cOYxeryt2qns2bN68Vn2VcXBwvv/wymzdvxtra2tjFKZN75paQu7s7ZmZmN/VwTkpKwtvb20ilqnrOzs40adKE6OhovL29KSwsJD09vcQxd3Odi8t9u8/R29ub5OTkEq9rNBrS0tLu2noDNGjQAHd3d6Kjo4G7r56TJ09m3bp1bN++nbp16xr2l+XPqbe3d6mfefFrpuJWdSxNUFAQQInP826oo6WlJY0aNaJDhw7MmTOHNm3asGDBglr1OcKt61mau/GzDAsLIzk5mfbt22Nubo65uTk7duzgyy+/xNzcHC8vL5P7PO+ZwGJpaUmHDh3YunWrYZ9Op2Pr1q0l7kve7bKzszl37hw+Pj506NABCwuLEnWOiooiNjb2rq1zQEAA3t7eJeqUmZnJgQMHDHUKDg4mPT2dsLAwwzHbtm1Dp9MZ/mG5G126dInU1FR8fHyAu6eeiqIwefJkVq9ezbZt2wgICCjxeln+nAYHB3P8+PESAW3z5s04OjoamumN6U51LE1ERARAic/TlOt4KzqdjoKCglrxOd5OcT1Lczd+ln369OH48eNEREQYto4dOzJy5EjDc5P7PKu8G68JW7FihWJlZaUsXrxYiYyMVJ599lnF2dm5RA/nu820adOU0NBQJSYmRtmzZ48SEhKiuLu7K8nJyYqi6Iel1atXT9m2bZty+PBhJTg4WAkODjZyqW8vKytLOXLkiHLkyBEFUD7//HPlyJEjysWLFxVF0Q9rdnZ2VtauXascO3ZMGTRoUKnDmtu1a6ccOHBA2b17t9K4cWOTG+57u3pmZWUp06dPV/bt26fExMQoW7ZsUdq3b680btxYyc/PN1zjbqjnxIkTFScnJyU0NLTEMNDc3FzDMXf6c1o8fLJv375KRESE8s8//ygeHh4mM0z0TnWMjo5W3nvvPeXw4cNKTEyMsnbtWqVBgwZKz549Ddcw9ToqiqK8/vrryo4dO5SYmBjl2LFjyuuvv66oVCpl06ZNiqLc/Z9jsdvVs7Z8lqX57+gnU/s876nAoiiK8n//939KvXr1FEtLS6Vz587K/v37jV2kShk2bJji4+OjWFpaKnXq1FGGDRumREdHG17Py8tTXnjhBcXFxUWxtbVVhgwZoiQkJBixxHe2fft2BbhpGz16tKIo+qHNb7/9tuLl5aVYWVkpffr0UaKiokpcIzU1VRkxYoRib2+vODo6KmPHjlWysrKMUJtbu109c3Nzlb59+yoeHh6KhYWF4u/vr0yYMOGmcH031LO0OgLKokWLDMeU5c/phQsXlAEDBig2NjaKu7u7Mm3aNKWoqKiGa1O6O9UxNjZW6dmzp+Lq6qpYWVkpjRo1UmbMmFFi7g5FMe06KoqiPPP
MM4q/v79iaWmpeHh4KH369DGEFUW5+z/HYrerZ235LEvz38Biap+nSlEUperbbYQQQgghqs4904dFCCGEEHcvCSxCCCGEMHkSWIQQQghh8iSwCCGEEMLkSWARQgghhMmTwCKEEEIIkyeBRQghhBAmTwKLEEIIIUyeBBYhhBBCmDwJLEIIkzBmzBgGDx5s7GIIIUyUBBYhhBBCmDwJLEKIGvXHH3/QqlUrbGxscHNzIyQkhBkzZrBkyRLWrl2LSqVCpVIRGhoKQFxcHE888QTOzs64uroyaNAgLly4YLheccvMu+++i4eHB46Ojjz//PMUFhYap4JCiGphbuwCCCHuHQkJCYwYMYJ58+YxZMgQsrKy2LVrF6NGjSI2NpbMzEwWLVoEgKurK0VFRfTr14/g4GB27dqFubk5H3zwAf379+fYsWNYWloCsHXrVqytrQkNDeXChQuMHTsWNzc3PvzwQ2NWVwhRhSSwCCFqTEJCAhqNhqFDh+Lv7w9Aq1atALCxsaGgoABvb2/D8b/88gs6nY4ffvgBlUoFwKJFi3B2diY0NJS+ffsCYGlpyU8//YStrS0tWrTgvffeY8aMGbz//vuo1dKQLERtIH+ThRA1pk2bNvTp04dWrVrx+OOP8/3333P16tVbHn/06FGio6NxcHDA3t4ee3t7XF1dyc/P59y5cyWua2tra/g5ODiY7Oxs4uLiqrU+QoiaIy0sQogaY2ZmxubNm9m7dy+bNm3i//7v/3jzzTc5cOBAqcdnZ2fToUMHfv3115te8/DwqO7iCiFMiAQWIUSNUqlUdOvWjW7dujFr1iz8/f1ZvXo1lpaWaLXaEse2b9+e3377DU9PTxwdHW95zaNHj5KXl4eNjQ0A+/fvx97eHj8/v2qtixCi5sgtISFEjTlw4AAfffQRhw8fJjY2llWrVpGSkkKzZs2oX78+x44dIyoqiitXrlBUVMTIkSNxd3dn0KBB7Nq1i5iYGEJDQ3nppZe4dOmS4bqFhYWMGzeOyMhINmzYwOzZs5k8ebL0XxGiFpEWFiFEjXF0dGTnzp3Mnz+fzMxM/P39+eyzzxgwYAAdO3YkNDSUjh07kp2dzfbt2+nVqxc7d+7ktddeY+jQoWRlZVGnTh369OlTosWlT58+NG7cmJ49e1JQUMCIESN45513jFdRIUSVUymKohi7EEIIUVFjxowhPT2dNWvWGLsoQohqJO2lQgghhDB5EliEEEIIYfLklpAQQgghTJ60sAghhBDC5ElgEUIIIYTJk8AihBBCCJMngUUIIYQQJk8CixBCCCFMngQWIYQQQpg8CSxCCCGEMHkSWIQQQghh8v4f9CmIIxzBROMAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plugin.loss_history.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "75ecc282", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
03.80.080.000.60.0090000.3462.09.0000000.987110440.01.038983.821.0800001.0814.27
33.80.08114.21.101.660.60.0090002.09.0000000.98711065.80.346289.0440.00.987113.821.08000014.260.228.05
423.81.101.660.60.00900065.80.346289.0440.0000000.9871103.821.080000440.00.987112.720.2214.26
533.81.100.000.60.00900065.80.0092.09.0000000.987110440.00.987112.721.088.07
43.81.101.6665.80.009289.09.00.987113.821.0799750.228.067
6514.20.081.101.660.60.00900065.80.346289.09.0000000.9871103.821.0800009.01.038982.720.2214.27
63.81.101.6665.80.0092.0440.00.987113.820.2214.26
\n", "
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", - "0 3.8 1.10 0.00 65.8 0.009000 \n", - "1 14.2 0.08 1.66 0.6 0.251377 \n", - "2 3.8 1.10 0.00 0.6 0.009000 \n", - "3 3.8 0.08 1.66 0.6 0.009000 \n", - "4 3.8 1.10 1.66 0.6 0.009000 \n", - "5 3.8 1.10 0.00 0.6 0.009000 \n", - "6 14.2 0.08 1.66 0.6 0.009000 \n", + "0 3.8 0.08 0.00 0.6 0.346 \n", + "1 14.2 1.10 1.66 65.8 0.346 \n", + "2 3.8 1.10 1.66 65.8 0.346 \n", + "3 3.8 1.10 0.00 65.8 0.009 \n", + "4 3.8 1.10 1.66 65.8 0.009 \n", + "5 14.2 1.10 1.66 65.8 0.346 \n", + "6 3.8 1.10 1.66 65.8 0.009 \n", "\n", - " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", - "0 289.0 50.104997 1.038893 3.82 0.220000 \n", - "1 289.0 9.000000 0.987291 3.82 1.080000 \n", - "2 2.0 9.000000 0.987110 3.82 1.080000 \n", - "3 2.0 9.000000 0.987110 3.82 1.080000 \n", - "4 289.0 440.000000 0.987110 3.82 1.080000 \n", - "5 2.0 9.000000 0.987110 3.82 1.079975 \n", - "6 289.0 9.000000 0.987110 3.82 1.080000 \n", + " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", + "0 2.0 440.0 1.03898 3.82 1.08 \n", + "1 289.0 440.0 0.98711 3.82 0.22 \n", + "2 289.0 440.0 0.98711 2.72 0.22 \n", + "3 2.0 440.0 0.98711 2.72 1.08 \n", + "4 289.0 9.0 0.98711 3.82 0.22 \n", + "5 289.0 9.0 1.03898 2.72 0.22 \n", + "6 2.0 440.0 0.98711 3.82 0.22 \n", "\n", " alcohol quality \n", - "0 8.0 5 \n", - "1 8.0 6 \n", - "2 14.2 7 \n", - "3 14.2 6 \n", - "4 14.2 6 \n", - "5 8.0 6 \n", - "6 14.2 7 " + "0 14.2 7 \n", + "1 8.0 5 \n", + "2 14.2 6 \n", + "3 8.0 7 \n", + "4 8.0 7 \n", + "5 14.2 7 \n", + "6 14.2 6 " ] }, - "execution_count": 16, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -1600,7 +1843,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, From 57816b6c265e2d2949e16e4ff38799d262428b86 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: 
Thu, 6 Apr 2023 19:45:50 +0200 Subject: [PATCH 44/95] update pandas and torch version requirement --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7e20f43c..45c9dd4a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=1.3 - torch>=1.10.0 + pandas>=2.0 + torch>=2.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 From cc7e8fb50a5f9a241c4798d7954f4e157e18c1e7 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 6 Apr 2023 19:46:28 +0200 Subject: [PATCH 45/95] update pandas and torch version requirement --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7be378de..ad8686c2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=1.3 - torch>=1.10.0 + pandas>=2.0 + torch>=2.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 From 8a589966297a34e6a37526d96fa192e55423ef85 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 6 Apr 2023 19:57:34 +0200 Subject: [PATCH 46/95] update ddpm tutorial --- .../models/tabular_ddpm/gaussian_multinomial_diffsuion.py | 8 +++++--- .../tutorial8_tabular_modelling_with_diffusion.ipynb | 4 +--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py index 6c11c5be..43ebf091 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py @@ -152,7 +152,8 @@ def __init__( self.posterior_variance = ( betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - ) + ).to(device) + self.posterior_log_variance_clipped = ( torch.from_numpy( np.log( @@ 
-162,11 +163,13 @@ def __init__( .float() .to(device) ) + self.posterior_mean_coef1 = ( - (betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ((betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))) .float() .to(device) ) + self.posterior_mean_coef2 = ( ( (1.0 - alphas_cumprod_prev) @@ -288,7 +291,6 @@ def gaussian_p_mean_variance( ], dim=0, ) - # model_variance = self.posterior_variance.to(x.device) model_log_variance = torch.log(model_variance) model_variance = perm_and_expand(model_variance, t, x.shape) diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb index d07618a1..686c67f5 100644 --- a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb +++ b/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb @@ -1581,9 +1581,7 @@ ], "source": [ "import random\n", - "from sklearn.preprocessing import LabelEncoder\n", - "cond = random.choices(['red', 'white', 'rose'], k=len(loader))\n", - "cond = LabelEncoder().fit_transform(cond)\n", + "cond = random.choices(outcome, k=len(loader))\n", "plugin.fit(loader, cond=cond)" ] }, From cef348ef2c0450ba2adb73b691711143caaf0eb1 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 6 Apr 2023 20:46:21 +0200 Subject: [PATCH 47/95] restore setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 17ac704e..06c9fd74 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=2.0 - torch>=2.0 + pandas>=1.3 + torch>=1.10.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 From 9cb5da17b54afaab6f50745f438f585b0ce2e965 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 6 Apr 2023 20:47:08 +0200 Subject: [PATCH 48/95] restore setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 
cab8c527..de0ac067 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=2.0 - torch>=2.0 + pandas>=1.3 + torch>=1.10.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 From fe5ff2552048b2f7a6abd7c03b69ef9674795ce4 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 08:55:17 +0200 Subject: [PATCH 49/95] replace LabelEncoder with OrdinalEncoder --- src/synthcity/plugins/core/models/factory.py | 4 ++-- src/synthcity/plugins/core/models/feature_encoder.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index e2d69525..40b6d6d6 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -12,9 +12,9 @@ DatetimeEncoder, FeatureEncoder, GaussianQuantileTransformer, - LabelEncoder, MinMaxScaler, OneHotEncoder, + OrdinalEncoder, RobustScaler, StandardScaler, ) @@ -54,7 +54,7 @@ FEATURE_ENCODERS = dict( datetime=DatetimeEncoder, onehot=OneHotEncoder, - label=LabelEncoder, + ordinal=OrdinalEncoder, standard=StandardScaler, minmax=MinMaxScaler, robust=RobustScaler, diff --git a/src/synthcity/plugins/core/models/feature_encoder.py b/src/synthcity/plugins/core/models/feature_encoder.py index 70807e31..93455162 100644 --- a/src/synthcity/plugins/core/models/feature_encoder.py +++ b/src/synthcity/plugins/core/models/feature_encoder.py @@ -8,9 +8,9 @@ from sklearn.base import BaseEstimator, TransformerMixin from sklearn.mixture import BayesianGaussianMixture from sklearn.preprocessing import ( - LabelEncoder, MinMaxScaler, OneHotEncoder, + OrdinalEncoder, QuantileTransformer, RobustScaler, StandardScaler, @@ -151,7 +151,7 @@ def get_feature_names_out(self) -> List[str]: OneHotEncoder = FeatureEncoder.wraps(OneHotEncoder, categorical=True) -LabelEncoder = FeatureEncoder.wraps(LabelEncoder, n_dim_out=1, 
categorical=True) +OrdinalEncoder = FeatureEncoder.wraps(OrdinalEncoder, categorical=True) StandardScaler = FeatureEncoder.wraps(StandardScaler) MinMaxScaler = FeatureEncoder.wraps(MinMaxScaler) RobustScaler = FeatureEncoder.wraps(RobustScaler) From 2922a1dd77342b0676acbe263b15a33b4780b3d8 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 08:56:33 +0200 Subject: [PATCH 50/95] update setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index de0ac067..ef292c71 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=1.3 - torch>=1.10.0 + pandas>=1.3,<2.0 + torch>=1.10.0,<2.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 From 11fb825f5a35a9ff6238f362234c9d7c8d481195 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 08:57:14 +0200 Subject: [PATCH 51/95] update setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 06c9fd74..5788d984 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=1.3 - torch>=1.10.0 + pandas>=1.3,<2.0 + torch>=1.10.0,<2.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 From 9222b4e5bb9ec27b84362a902706884a6cb8b525 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 09:57:13 +0200 Subject: [PATCH 52/95] debug datetimeDistribution --- src/synthcity/plugins/core/distribution.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index 96ce24db..485c9a5a 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -383,25 +383,29 @@ class DatetimeDistribution(Distribution): :parts: 1 """ + offset: int = 120 low: 
datetime = datetime.utcfromtimestamp(0) high: datetime = datetime.now() - offset: int = 120 + + @validator("offset", always=True) + def _validate_offset(cls: Any, v: int) -> int: + if v < 0: + raise ValueError("offset must be greater than 0") + return v @validator("low", always=True) def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" if mkey in values and values[mkey] is not None: v = values[mkey].index.min() - - return v - timedelta(seconds=cls.offset) + return v - timedelta(seconds=values["offset"]) @validator("high", always=True) def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" if mkey in values and values[mkey] is not None: v = values[mkey].index.max() - - return v + timedelta(seconds=cls.offset) + return v + timedelta(seconds=values["offset"]) def get(self) -> List[Any]: return [self.name, self.low, self.high] From 95302b9c9c9fe9abcfbf0b7c008dca83fde6dc29 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 10:00:38 +0200 Subject: [PATCH 53/95] clean --- tests/plugins/generic/test_ddpm.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index 8fc4664a..733d367c 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -23,12 +23,6 @@ num_timesteps=100, model_type="mlp", ) -# plugin_params = dict( -# n_iter=1000, -# batch_size=1000, -# num_timesteps=30, -# model_type="tabnet", -# ) def extend_fixtures( From 785db826b5f8433ec387814f6c4cdc606f8e9dd2 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 12:34:20 +0200 Subject: [PATCH 54/95] update setup.cfg and goggle test --- setup.cfg | 2 -- tests/plugins/generic/test_goggle.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 5788d984..caa59d0e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,8 
+58,6 @@ install_requires = monai tsai; python_version>"3.7" importlib-metadata; python_version<"3.8" - igraph - pytest-cov [options.packages.find] where = src diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 9b194ae0..2c9b5f4a 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -17,7 +17,7 @@ plugin_name = "goggle" plugin_args = { - "n_iter": 10, + "n_iter": 500, "device": "cpu", } From 27cc95c2a6547341965ad3f0af492b4a7c179a71 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 22:17:14 +0200 Subject: [PATCH 55/95] move DDPM tutorial to tutorials/plugins --- tests/plugins/generic/test_ddpm.py | 2 +- .../generic/plugin_ddpm.ipynb} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename tutorials/{tutorial8_tabular_modelling_with_diffusion.ipynb => plugins/generic/plugin_ddpm.ipynb} (99%) diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index 733d367c..c6e3e319 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -18,7 +18,7 @@ plugin_name = "ddpm" plugin_params = dict( - n_iter=1000, + n_iter=500, batch_size=1000, num_timesteps=100, model_type="mlp", diff --git a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb b/tutorials/plugins/generic/plugin_ddpm.ipynb similarity index 99% rename from tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb rename to tutorials/plugins/generic/plugin_ddpm.ipynb index 686c67f5..14ed6929 100644 --- a/tutorials/tutorial8_tabular_modelling_with_diffusion.ipynb +++ b/tutorials/plugins/generic/plugin_ddpm.ipynb @@ -6,7 +6,7 @@ "id": "97e2d93c", "metadata": {}, "source": [ - "# Tutorial 8: Modelling tabular data with diffusion models\n", + "# Modelling tabular data with diffusion models\n", "\n", "This tutorial demonstrates hot to use a denoising diffusion probabilistic model (DDPM) to synthesize tabular data. 
The algorithm was proposed in [TabDDPM: Modelling Tabular Data with Diffusion Models](https://arxiv.org/abs/2209.15421)." ] From 1d7c77c880e9529bfcfc9a4f5e2f7beb81148b42 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 22:17:53 +0200 Subject: [PATCH 56/95] update tabnet.py reference --- src/synthcity/plugins/core/models/tabnet.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/synthcity/plugins/core/models/tabnet.py b/src/synthcity/plugins/core/models/tabnet.py index c5fc702f..353dabb8 100644 --- a/src/synthcity/plugins/core/models/tabnet.py +++ b/src/synthcity/plugins/core/models/tabnet.py @@ -1,3 +1,10 @@ +""" +TabNet: Attentive Interpretable Tabular Learning +Reference: +- https://arxiv.org/pdf/1908.07442.pdf +- https://github.com/dreamquark-ai/tabnet +""" + # stdlib from typing import List, Optional, Tuple From 6c25377282b29cf829c2f2c85f5176efabe57194 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Fri, 7 Apr 2023 23:37:19 +0200 Subject: [PATCH 57/95] update tab_ddpm --- .../core/models/tabular_ddpm/__init__.py | 14 +++++------- src/synthcity/plugins/generic/plugin_ddpm.py | 2 -- tests/plugins/generic/test_ddpm.py | 22 +++++++++++-------- tests/plugins/generic/test_goggle.py | 3 +++ 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py index b70d0d33..cb4ea67a 100644 --- a/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py +++ b/src/synthcity/plugins/core/models/tabular_ddpm/__init__.py @@ -10,7 +10,7 @@ from pydantic import validate_arguments from torch import nn from torch.utils.data import DataLoader, TensorDataset -from tqdm import tqdm +from tqdm import trange # synthcity absolute from synthcity.logger import info @@ -38,7 +38,6 @@ def __init__( callbacks: Sequence[Callback] = (), device: torch.device = DEVICE, log_interval: int = 10, - print_interval: int = 
100, # model params model_type: str = "mlp", model_params: Optional[dict] = None, @@ -141,8 +140,9 @@ def fit( curr_loss_multi = 0.0 curr_loss_gauss = 0.0 curr_count = 0 + pbar = trange(self.n_iter, desc="Epoch", leave=True) - for epoch in tqdm(range(self.n_iter)): + for epoch in pbar: self.epoch = epoch + 1 self.diffusion.train() @@ -166,21 +166,19 @@ def fit( if steps % self.log_interval == 0: mloss = np.around(curr_loss_multi / curr_count, 4) gloss = np.around(curr_loss_gauss / curr_count, 4) - if steps % self.print_interval == 0: - info( - f"Step {steps}: MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}" - ) + loss = mloss + gloss self.loss_history.append( [ steps, mloss, gloss, - mloss + gloss, + loss, ] ) curr_count = 0 curr_loss_gauss = 0.0 curr_loss_multi = 0.0 + pbar.set_postfix(loss=loss) self._update_ema( self.ema_model.parameters(), self.diffusion.parameters() diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 09826a97..cc851a8e 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -103,7 +103,6 @@ def __init__( device: Any = DEVICE, callbacks: Sequence[Callback] = (), log_interval: int = 100, - print_interval: int = 500, model_type: str = "mlp", model_params: dict = {}, dim_embed: int = 128, @@ -139,7 +138,6 @@ def __init__( device=device, callbacks=callbacks, log_interval=log_interval, - print_interval=print_interval, model_type=model_type, model_params=model_params.copy(), dim_embed=dim_embed, diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index c6e3e319..ae0462d8 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -22,6 +22,7 @@ batch_size=1000, num_timesteps=100, model_type="mlp", + sampling_patience=100, ) @@ -84,19 +85,21 @@ def test_plugin_generate(test_plugin: Plugin) -> None: "test_plugin", extend_fixtures(is_classification=[True, False]) ) def 
test_plugin_generate_constraints(test_plugin: Plugin) -> None: - X = pd.DataFrame(load_iris()["data"]) + X, y = load_iris(as_frame=True, return_X_y=True) + X["target"] = y test_plugin.fit(GenericDataLoader(X)) constraints = Constraints( rules=[ - ("0", "le", 6), - ("0", "ge", 4.3), - ("1", "le", 4.4), - ("1", "ge", 3), - ("2", "le", 5.5), - ("2", "ge", 1.0), - ("3", "le", 2), - ("3", "ge", 0.1), + ("target", "eq", 1), + ("sepal length (cm)", "le", 6), + ("sepal length (cm)", "ge", 4.3), + ("sepal width (cm)", "le", 4.4), + ("sepal width (cm)", "ge", 3), + ("petal length (cm)", "le", 5.5), + ("petal length (cm)", "ge", 1.0), + ("petal width (cm)", "le", 2), + ("petal width (cm)", "ge", 0.1), ] ) @@ -104,6 +107,7 @@ def test_plugin_generate_constraints(test_plugin: Plugin) -> None: assert len(X_gen) == len(X) assert test_plugin.schema_includes(X_gen) assert constraints.filter(X_gen).sum() == len(X_gen) + assert (X_gen["target"] == 1).all() X_gen = test_plugin.generate(count=50, constraints=constraints).dataframe() assert len(X_gen) == 50 diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 2c9b5f4a..9b58ac4e 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -106,6 +106,9 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: assert (X_gen1.numpy() != X_gen3.numpy()).any() +is_missing_goggle_deps = True + + @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", From 3623d37c8e075e1b085b640cbd899b3748daf490 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 8 Apr 2023 17:52:04 +0200 Subject: [PATCH 58/95] update distribution, add optuna utils and tutorial --- src/synthcity/plugins/core/distribution.py | 134 +- src/synthcity/plugins/generic/plugin_ddpm.py | 20 +- src/synthcity/utils/optuna_sample.py | 33 + ...utorial8_hyperparameter_optimization.ipynb | 9862 
+++++++++++++++++ 4 files changed, 9980 insertions(+), 69 deletions(-) create mode 100644 src/synthcity/utils/optuna_sample.py create mode 100644 tutorials/tutorial8_hyperparameter_optimization.ipynb diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index fb486e0a..06b3a99b 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -111,17 +111,25 @@ def as_constraint(self) -> Constraints: @abstractmethod def min(self) -> Any: - "Get the min value of the distribution" + """Get the min value of the distribution.""" ... @abstractmethod def max(self) -> Any: - "Get the max value of the distribution" + """Get the max value of the distribution.""" ... - @abstractmethod def __eq__(self, other: Any) -> bool: - ... + return type(self) == type(other) and self.get() == other.get() + + def __contains__(self, item: Any) -> bool: + """ + Example: + >>> dist = CategoricalDistribution(name="foo", choices=["a", "b", "c"]) + >>> "a" in dist + True + """ + return self.has(item) @abstractmethod def dtype(self) -> str: @@ -146,7 +154,7 @@ def _validate_choices(cls: Any, v: List, values: Dict) -> List: raise ValueError( "Invalid choices for CategoricalDistribution. 
Provide data or choices params" ) - return v + return sorted(set(v)) def get(self) -> List[Any]: return [self.name, self.choices] @@ -176,12 +184,6 @@ def min(self) -> Any: def max(self) -> Any: return max(self.choices) - def __eq__(self, other: Any) -> bool: - if not isinstance(other, CategoricalDistribution): - return False - - return self.name == other.name and set(self.choices) == set(other.choices) - def dtype(self) -> str: types = { "object": 0, @@ -259,20 +261,26 @@ def min(self) -> Any: def max(self) -> Any: return self.high - def __eq__(self, other: Any) -> bool: - if not isinstance(other, FloatDistribution): - return False - - return ( - self.name == other.name - and self.low == other.low - and self.high == other.high - ) - def dtype(self) -> str: return "float" +class LogDistribution(FloatDistribution): + low: float = np.finfo(np.float64).tiny + high: float = np.finfo(np.float64).max + + def get(self) -> List[Any]: + return [self.name, self.low, self.high] + + def sample(self, count: int = 1) -> Any: + np.random.seed(self.random_state) + msamples = self.sample_marginal(count) + if msamples is not None: + return msamples + lo, hi = np.log2(self.low), np.log2(self.high) + return 2.0 ** np.random.uniform(lo, hi, count) + + class IntegerDistribution(Distribution): """ .. 
inheritance-diagram:: synthcity.plugins.core.distribution.IntegerDistribution @@ -298,6 +306,12 @@ def _validate_high_thresh(cls: Any, v: int, values: Dict) -> int: return int(values[mkey].index.max()) return v + @validator("step", always=True) + def _validate_step(cls: Any, v: int, values: Dict) -> int: + if v < 1: + raise ValueError("Step must be greater than 0") + return v + def get(self) -> List[Any]: return [self.name, self.low, self.high, self.step] @@ -307,8 +321,9 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - choices = [val for val in range(self.low, self.high + 1, self.step)] - return np.random.choice(choices, count).tolist() + steps = (self.high - self.low) // self.step + samples = np.random.choice(steps + 1, count) + return samples * self.step + self.low def has(self, val: Any) -> bool: return self.low <= val and val <= self.high @@ -331,21 +346,33 @@ def min(self) -> Any: def max(self) -> Any: return self.high - def __eq__(self, other: Any) -> bool: - if not isinstance(other, IntegerDistribution): - return False - - return ( - self.name == other.name - and self.low == other.low - and self.high == other.high - ) - def dtype(self) -> str: return "int" -OFFSET = 120 +class IntLogDistribution(IntegerDistribution): + low: int = 1 + high: int = np.iinfo(np.int64).max + step: int = 2 # the next sample larger than x is step * x + + @validator("step", always=True) + def _validate_step(cls: Any, v: int, values: Dict) -> int: + if v < 2: + raise ValueError("Step must be greater than 1") + return v + + def get(self) -> List[Any]: + return [self.name, self.low, self.high, self.step] + + def sample(self, count: int = 1) -> Any: + np.random.seed(self.random_state) + msamples = self.sample_marginal(count) + if msamples is not None: + return msamples + steps = int(np.log2(self.high / self.low) / np.log2(self.step)) + samples = np.random.choice(steps + 1, count) + samples = self.low * self.step**samples + return 
samples.astype(int) class DatetimeDistribution(Distribution): @@ -356,25 +383,25 @@ class DatetimeDistribution(Distribution): low: datetime = datetime.utcfromtimestamp(0) high: datetime = datetime.now() + step: timedelta = timedelta(microseconds=1) + offset: timedelta = timedelta(seconds=120) @validator("low", always=True) def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" if mkey in values and values[mkey] is not None: v = values[mkey].index.min() - - return v - timedelta(seconds=OFFSET) + return v @validator("high", always=True) def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" if mkey in values and values[mkey] is not None: v = values[mkey].index.max() - - return v + timedelta(seconds=OFFSET) + return v def get(self) -> List[Any]: - return [self.name, self.low, self.high] + return [self.name, self.low, self.high, self.step, self.offset] def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) @@ -382,23 +409,18 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - samples = np.random.uniform( - datetime.timestamp(self.low), datetime.timestamp(self.high), count - ) - - samples_dt = [] - for s in samples: - samples_dt.append(datetime.fromtimestamp(s)) - - return samples_dt + n = (self.high - self.low) // self.step + 1 + samples = np.round(np.random.rand(count) * n - 0.5) + return self.low + samples * self.step def has(self, val: datetime) -> bool: return self.low <= val and val <= self.high def includes(self, other: "Distribution") -> bool: - return self.min() - timedelta( - seconds=OFFSET - ) <= other.min() and other.max() <= self.max() + timedelta(seconds=OFFSET) + return ( + self.min() - self.offset <= other.min() + and other.max() <= self.max() + self.offset + ) def as_constraint(self) -> Constraints: return Constraints( @@ -415,16 +437,6 @@ def min(self) -> Any: def max(self) -> Any: return 
self.high - def __eq__(self, other: Any) -> bool: - if not isinstance(other, DatetimeDistribution): - return False - - return ( - self.name == other.name - and self.low == other.low - and self.high == other.high - ) - def dtype(self) -> str: return "datetime" diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 631480fc..0af21cf7 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -15,7 +15,12 @@ # synthcity absolute from synthcity.plugins.core.dataloader import DataLoader -from synthcity.plugins.core.distribution import CategoricalDistribution, Distribution +from synthcity.plugins.core.distribution import ( + Distribution, + IntegerDistribution, + IntLogDistribution, + LogDistribution, +) from synthcity.plugins.core.models.tabular_ddpm import TabDDPM from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema @@ -174,13 +179,12 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: Gaussian diffusion loss MSE """ return [ - # TODO: change to loguniform distribution - CategoricalDistribution(name="lr", choices=[1e-5, 1e-4, 1e-3, 2e-3, 3e-3]), - CategoricalDistribution(name="batch_size", choices=[256, 4096]), - CategoricalDistribution(name="num_timesteps", choices=[100, 1000]), - CategoricalDistribution(name="n_iter", choices=[5000, 10000, 20000]), - CategoricalDistribution(name="n_layers_hidden", choices=[2, 4, 6, 8]), - CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]), + LogDistribution(name="lr", low=1e-5, high=1e-1), + IntLogDistribution(name="batch_size", low=256, high=4096), + IntegerDistribution(name="num_timesteps", low=10, high=1000), + IntLogDistribution(name="n_iter", low=1000, high=10000), + IntegerDistribution(name="n_layers_hidden", low=2, high=8), + IntLogDistribution(name="dim_hidden", low=128, high=1024), ] def _fit(self, X: DataLoader, *args: Any, **kwargs: 
Any) -> "TabDDPMPlugin": diff --git a/src/synthcity/utils/optuna_sample.py b/src/synthcity/utils/optuna_sample.py new file mode 100644 index 00000000..c19dae66 --- /dev/null +++ b/src/synthcity/utils/optuna_sample.py @@ -0,0 +1,33 @@ +# stdlib +from typing import Any, Dict, List + +# third party +import optuna + +# synthcity absolute +import synthcity.plugins.core.distribution as D + + +def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any: + if isinstance(dist, D.FloatDistribution): + return trial.suggest_float(dist.name, dist.low, dist.high) + elif isinstance(dist, D.LogDistribution): + return trial.suggest_float(dist.name, dist.low, dist.high, log=True) + elif isinstance(dist, D.IntegerDistribution): + return trial.suggest_int(dist.name, dist.low, dist.high, dist.step) + elif isinstance(dist, D.IntLogDistribution): + # ! does not handle step yet + return trial.suggest_int(dist.name, dist.low, dist.high, log=True) + elif isinstance(dist, D.CategoricalDistribution): + return trial.suggest_categorical(dist.name, dist.choices) + # ! 
the modification cannot be reflected in study.best_params + # elif isinstance(dist, D.DatetimeDistribution): + # high = (dist.high - dist.low) / dist.step + # s = trial.suggest_float(dist.name, 0, high) + # return dist.low + dist.step * s + else: + raise ValueError(f"Unknown dist: {dist}") + + +def suggest_all(trial: optuna.Trial, distributions: List[D.Distribution]) -> Dict: + return {dist.name: suggest(trial, dist) for dist in distributions} diff --git a/tutorials/tutorial8_hyperparameter_optimization.ipynb b/tutorials/tutorial8_hyperparameter_optimization.ipynb new file mode 100644 index 00000000..4cf7c965 --- /dev/null +++ b/tutorials/tutorial8_hyperparameter_optimization.ipynb @@ -0,0 +1,9862 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tutorial 8: Hyperparameter Optimization\n", + "\n", + "To automatically tune hyperparameters in a `synthcity` plugin to generate more realistic data, we use hyperparameter optimization (HPO) algorithms such as Tree-structured Parzen estimators (TPE), Bayesian optimization, and genetic programming. In this tutorial we will use `optuna`, a very popular HPO library implementing TPE, to tune the hyperparameters of the `nflow` plugin to synthesize the diabetes dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[KeOps] Warning : \n", + " The default C++ compiler could not be found on your system.\n", + " You need to either define the CXX environment variable or a symlink to the g++ command.\n", + " For example if g++-8 is the command you can do\n", + " import os\n", + " os.environ['CXX'] = 'g++-8'\n", + " \n", + "[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode\n" + ] + } + ], + "source": [ + "# stdlib\n", + "import sys\n", + "import warnings\n", + "\n", + "# third party\n", + "import optuna\n", + "from sklearn.datasets import load_diabetes\n", + "\n", + "# synthcity absolute\n", + "import synthcity.logger as log\n", + "from synthcity.plugins import Plugins\n", + "from synthcity.plugins.core.dataloader import GenericDataLoader\n", + "\n", + "log.add(sink=sys.stderr, level=\"INFO\")\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
agesexbmibps1s2s3s4s5s6target
00.0380760.0506800.0616960.021872-0.044223-0.034821-0.043401-0.0025920.019907-0.017646151.0
1-0.001882-0.044642-0.051474-0.026328-0.008449-0.0191630.074412-0.039493-0.068332-0.09220475.0
20.0852990.0506800.044451-0.005670-0.045599-0.034194-0.032356-0.0025920.002861-0.025930141.0
3-0.089063-0.044642-0.011595-0.0366560.0121910.024991-0.0360380.0343090.022688-0.009362206.0
40.005383-0.044642-0.0363850.0218720.0039350.0155960.008142-0.002592-0.031988-0.046641135.0
....................................
4370.0417080.0506800.0196620.059744-0.005697-0.002566-0.028674-0.0025920.0311930.007207178.0
438-0.0055150.050680-0.015906-0.0676420.0493410.079165-0.0286740.034309-0.0181140.044485104.0
4390.0417080.050680-0.0159060.017293-0.037344-0.013840-0.024993-0.011080-0.0468830.015491132.0
440-0.045472-0.0446420.0390620.0012150.0163180.015283-0.0286740.0265600.044529-0.025930220.0
441-0.045472-0.044642-0.073030-0.0814130.0837400.0278090.173816-0.039493-0.0042220.00306457.0
\n", + "

442 rows × 11 columns

\n", + "
" + ], + "text/plain": [ + " age sex bmi bp s1 s2 s3 \\\n", + "0 0.038076 0.050680 0.061696 0.021872 -0.044223 -0.034821 -0.043401 \n", + "1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 -0.019163 0.074412 \n", + "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 -0.034194 -0.032356 \n", + "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 0.024991 -0.036038 \n", + "4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 \n", + ".. ... ... ... ... ... ... ... \n", + "437 0.041708 0.050680 0.019662 0.059744 -0.005697 -0.002566 -0.028674 \n", + "438 -0.005515 0.050680 -0.015906 -0.067642 0.049341 0.079165 -0.028674 \n", + "439 0.041708 0.050680 -0.015906 0.017293 -0.037344 -0.013840 -0.024993 \n", + "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 0.015283 -0.028674 \n", + "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 0.027809 0.173816 \n", + "\n", + " s4 s5 s6 target \n", + "0 -0.002592 0.019907 -0.017646 151.0 \n", + "1 -0.039493 -0.068332 -0.092204 75.0 \n", + "2 -0.002592 0.002861 -0.025930 141.0 \n", + "3 0.034309 0.022688 -0.009362 206.0 \n", + "4 -0.002592 -0.031988 -0.046641 135.0 \n", + ".. ... ... ... ... 
\n", + "437 -0.002592 0.031193 0.007207 178.0 \n", + "438 0.034309 -0.018114 0.044485 104.0 \n", + "439 -0.011080 -0.046883 0.015491 132.0 \n", + "440 0.026560 0.044529 -0.025930 220.0 \n", + "441 -0.039493 -0.004222 0.003064 57.0 \n", + "\n", + "[442 rows x 11 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X, y = load_diabetes(return_X_y=True, as_frame=True)\n", + "X[\"target\"] = y\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "loader = GenericDataLoader(\n", + " X,\n", + " target_column=\"target\",\n", + " sensitive_columns=[\"sex\"],\n", + ")\n", + "train_loader, test_loader = loader.train(), loader.test()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the plugin class" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-04-07T21:51:56.689921+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n" + ] + }, + { + "data": { + "text/plain": [ + "synthcity.plugins.generic.plugin_nflow.NormalizingFlowsPlugin" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PLUGIN = \"nflow\"\n", + "plugin_cls = type(Plugins().get(PLUGIN))\n", + "plugin_cls" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display the hyperparameter space" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[IntegerDistribution(name='n_iter', data=None, random_state=0, marginal_distribution=None, low=100, high=5000, step=100),\n", + " IntegerDistribution(name='n_layers_hidden', data=None, random_state=0, 
marginal_distribution=None, low=1, high=10, step=1),\n", + " IntegerDistribution(name='n_units_hidden', data=None, random_state=0, marginal_distribution=None, low=10, high=100, step=1),\n", + " CategoricalDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, choices=[32, 64, 128, 256, 512]),\n", + " FloatDistribution(name='dropout', data=None, random_state=0, marginal_distribution=None, low=0.0, high=0.2),\n", + " CategoricalDistribution(name='batch_norm', data=None, random_state=0, marginal_distribution=None, choices=[True, False]),\n", + " CategoricalDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, choices=[0.001, 0.0001, 0.0002]),\n", + " CategoricalDistribution(name='linear_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['lu', 'permutation', 'svd']),\n", + " CategoricalDistribution(name='base_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['affine-coupling', 'quadratic-coupling', 'rq-coupling', 'affine-autoregressive', 'quadratic-autoregressive', 'rq-autoregressive'])]" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plugin_cls.hyperparameter_space()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use a trial to suggest a set of hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'n_iter': 486,\n", + " 'n_layers_hidden': 10,\n", + " 'n_units_hidden': 87,\n", + " 'batch_size': 512,\n", + " 'dropout': 0.016022465975681178,\n", + " 'batch_norm': True,\n", + " 'lr': 0.001,\n", + " 'linear_transform_type': 'svd',\n", + " 'base_transform_type': 'affine-coupling'}" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from synthcity.utils.optuna_sample import suggest_all\n", 
+ "\n", + "trial = optuna.create_study().ask()\n", + "params = suggest_all(trial, plugin_cls.hyperparameter_space())\n", + "params" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate the plugin with the suggested hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 62%|██████▏ | 299/486 [01:26<00:53, 3.47it/s]\n", + "[2023-04-07T21:53:29.785866+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + " 62%|██████▏ | 299/486 [01:30<00:56, 3.31it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
minmaxmeanstddevmedianiqrroundserrorsdurationsdirection
detection.detection_mlp.mean0.3900710.3900710.3900710.00.3900710.0102.51minimize
\n", + "
" + ], + "text/plain": [ + " min max mean stddev median \\\n", + "detection.detection_mlp.mean 0.390071 0.390071 0.390071 0.0 0.390071 \n", + "\n", + " iqr rounds errors durations direction \n", + "detection.detection_mlp.mean 0.0 1 0 2.51 minimize " + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from synthcity.benchmark import Benchmarks\n", + "\n", + "plugin = plugin_cls(**params).fit(train_loader)\n", + "report = Benchmarks.evaluate(\n", + " [(\"trial\", PLUGIN, params)],\n", + " train_loader, # Benchmarks.evaluate will split out a validation set\n", + " repeats=1,\n", + " metrics={\"detection\": [\"detection_mlp\"]}, # DELETE THIS LINE FOR ALL METRICS\n", + ")\n", + "report['trial']" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an Optuna study and optimize the hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-04-07T21:57:37.669090+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T21:57:37.689827+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:12<00:00, 8.05it/s]\n", + "[2023-04-07T21:57:53.690237+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T21:57:53.712601+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:04<00:00, 24.77it/s]\n", + "[2023-04-07T21:58:01.728358+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + 
"[2023-04-07T21:58:01.744010+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:24<00:00, 4.08it/s]\n", + "[2023-04-07T21:58:32.292499+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T21:58:32.316002+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [01:30<00:00, 1.10it/s]\n", + "[2023-04-07T22:00:38.652411+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T22:00:38.685914+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:23<00:00, 4.21it/s]\n", + "[2023-04-07T22:01:09.148491+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T22:01:09.178259+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:06<00:00, 14.79it/s]\n", + "[2023-04-07T22:01:20.722191+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T22:01:20.751419+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [02:00<00:00, 1.21s/it]\n", + "[2023-04-07T22:03:29.180475+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T22:03:29.211421+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + 
"100%|██████████| 100/100 [01:37<00:00, 1.02it/s]\n", + "[2023-04-07T22:05:12.012437+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T22:05:12.030781+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:57<00:00, 1.74it/s]\n", + "[2023-04-07T22:06:14.408112+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-07T22:06:14.431469+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:11<00:00, 8.47it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'n_iter': 4929,\n", + " 'n_layers_hidden': 1,\n", + " 'n_units_hidden': 65,\n", + " 'batch_size': 256,\n", + " 'dropout': 0.04046713177503456,\n", + " 'batch_norm': True,\n", + " 'lr': 0.001,\n", + " 'linear_transform_type': 'lu',\n", + " 'base_transform_type': 'affine-coupling'}" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def objective(trial: optuna.Trial):\n", + " hp_space = Plugins().get(PLUGIN).hyperparameter_space()\n", + " params = suggest_all(trial, hp_space[1:]) # fix n_iter=100 for speed\n", + " params['n_iter'] = 100\n", + " ID = f\"trial_{trial.number}\"\n", + " report = Benchmarks.evaluate(\n", + " [(ID, PLUGIN, params)],\n", + " train_loader,\n", + " repeats=1,\n", + " metrics={\"detection\": [\"detection_mlp\"]}, # DELETE THIS LINE FOR ALL METRICS\n", + " )\n", + " score = report[ID].query('direction == \"minimize\"')['mean'].mean()\n", + " # average score across all metrics with direction=\"minimize\"\n", + " return score\n", + "\n", + "study = optuna.create_study(direction=\"minimize\")\n", + "study.optimize(objective, n_trials=10)\n", + "study.best_params" + ] 
+ }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize the study" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "mode": "markers", + "name": "Objective Value", + "type": "scatter", + "x": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9 + ], + "y": [ + 0.5199275362318841, + 0.4701086956521739, + 0.518719806763285, + 0.5, + 0.5, + 0.5, + 0.5, + 0.49516908212560384, + 0.5, + 0.5 + ] + }, + { + "name": "Best Value", + "type": "scatter", + "x": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9 + ], + "y": [ + 0.5199275362318841, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739, + 0.4701086956521739 + ] + } + ], + "layout": { + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + 
"outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" 
+ ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + 
} + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 
0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": 
"white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Optimization History Plot" + }, + "xaxis": { + "title": { + "text": "Trial" + } + }, + "yaxis": { + "title": { + "text": "Objective Value" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from optuna.visualization import plot_contour\n", + "from optuna.visualization import plot_edf\n", + "from optuna.visualization import plot_optimization_history\n", + "from optuna.visualization import plot_parallel_coordinate\n", + "from optuna.visualization import plot_param_importances\n", + "from optuna.visualization import plot_slice\n", + "\n", + "plot_optimization_history(study)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "dimensions": [ + { + "label": "Objective Value", + "range": [ + 0.4701086956521739, + 0.5199275362318841 + ], + "values": [ + 0.518719806763285, + 0.49516908212560384, + 0.5, + 0.5, + 0.5199275362318841, + 0.4701086956521739, + 0.5, + 0.5, + 0.5, + 0.5 + ] + }, + { + "label": "base_transform_type", + "range": [ + 0, + 4 + ], + "ticktext": [ + "rq-coupling", + "affine-coupling", + "affine-autoregressive", + "quadratic-autoregressive", + "quadratic-coupling" + ], + "tickvals": [ + 0, + 1, + 2, 
+ 3, + 4 + ], + "values": [ + 2, + 4, + 3, + 1, + 0, + 1, + 4, + 0, + 1, + 1 + ] + }, + { + "label": "batch_norm", + "range": [ + 0, + 1 + ], + "ticktext": [ + "False", + "True" + ], + "tickvals": [ + 0, + 1 + ], + "values": [ + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0 + ] + }, + { + "label": "batch_size", + "range": [ + 0, + 4 + ], + "ticktext": [ + "32", + "64", + "128", + "256", + "512" + ], + "tickvals": [ + 0, + 1, + 2, + 3, + 4 + ], + "values": [ + 0, + 0, + 1, + 1, + 2, + 3, + 3, + 4, + 4, + 4 + ] + }, + { + "label": "dropout", + "range": [ + 0.04046713177503456, + 0.1816709592718398 + ], + "values": [ + 0.051928275495301705, + 0.0768868475443224, + 0.13154994845426374, + 0.13397751341486178, + 0.0805911907341199, + 0.04046713177503456, + 0.16534854040828872, + 0.1816709592718398, + 0.136820133194068, + 0.11575633038847206 + ] + }, + { + "label": "linear_transform_...", + "range": [ + 0, + 2 + ], + "ticktext": [ + "svd", + "lu", + "permutation" + ], + "tickvals": [ + 0, + 1, + 2 + ], + "values": [ + 2, + 2, + 1, + 0, + 0, + 1, + 1, + 2, + 2, + 1 + ] + }, + { + "label": "lr", + "range": [ + 0, + 2 + ], + "ticktext": [ + "0.0001", + "0.0002", + "0.001" + ], + "tickvals": [ + 0, + 1, + 2 + ], + "values": [ + 2, + 2, + 1, + 2, + 2, + 2, + 2, + 0, + 1, + 1 + ] + }, + { + "label": "n_iter", + "range": [ + 249, + 4929 + ], + "values": [ + 3068, + 1368, + 249, + 1525, + 1129, + 4929, + 1595, + 2629, + 3151, + 1295 + ] + }, + { + "label": "n_layers_hidden", + "range": [ + 1, + 9 + ], + "values": [ + 7, + 5, + 7, + 2, + 1, + 1, + 9, + 7, + 4, + 9 + ] + }, + { + "label": "n_units_hidden", + "range": [ + 13, + 99 + ], + "values": [ + 79, + 99, + 80, + 24, + 13, + 65, + 94, + 46, + 62, + 72 + ] + } + ], + "labelangle": 30, + "labelside": "bottom", + "line": { + "color": [ + 0.518719806763285, + 0.49516908212560384, + 0.5, + 0.5, + 0.5199275362318841, + 0.4701086956521739, + 0.5, + 0.5, + 0.5, + 0.5 + ], + "colorbar": { + "title": { + "text": "Objective Value" + } + 
}, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "reversescale": true, + "showscale": true + }, + "type": "parcoords" + } + ], + "layout": { + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + 
"ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 
0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 
0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + 
[ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + 
"standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Parallel Coordinate Plot" + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize high-dimensional parameter relationships. \n", + "plot_parallel_coordinate(study)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "type": "scatter", + "xaxis": "x", + "yaxis": "y" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": true, + "type": "contour", + "x": [ + 8, + 32, + 64, + 128, + 256, + 512, + 536 + ], + "xaxis": "x5", + "y": [ + 0.0334069404001943, + 0.04046713177503456, + 0.051928275495301705, + 0.0768868475443224, + 0.0805911907341199, + 0.11575633038847206, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.1816709592718398, + 0.18873115064668006 + ], + "yaxis": "y5", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 0.4701086956521739, + null, + null + ], + [ + null, + 0.518719806763285, + null, + null, + null, + null, + null + ], + [ + null, + 0.49516908212560384, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + 0.5199275362318841, + null, + null, + null + ], + [ + null, + null, + 
null, + null, + null, + 0.5, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 128, + 256, + 32, + 64, + 64, + 512, + 256, + 32, + 512, + 512 + ], + "xaxis": "x5", + "y": [ + 0.0805911907341199, + 0.04046713177503456, + 0.051928275495301705, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.0768868475443224, + 0.1816709592718398, + 0.11575633038847206 + ], + "yaxis": "y5" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 8, + 32, + 64, + 128, + 256, + 512, + 536 + ], + "xaxis": "x9", + "y": [ + 0.6, + 1, + 2, + 4, + 5, + 7, + 9, + 9.4 + ], + "yaxis": "y9", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + 0.5199275362318841, + 0.4701086956521739, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null + ], + 
[ + null, + 0.49516908212560384, + null, + null, + null, + null, + null + ], + [ + null, + 0.518719806763285, + 0.5, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + 0.5, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 128, + 256, + 32, + 64, + 64, + 512, + 256, + 32, + 512, + 512 + ], + "xaxis": "x9", + "y": [ + 1, + 1, + 7, + 7, + 2, + 4, + 9, + 5, + 7, + 9 + ], + "yaxis": "y9" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 8, + 32, + 64, + 128, + 256, + 512, + 536 + ], + "xaxis": "x13", + "y": [ + 8.7, + 13, + 24, + 46, + 62, + 65, + 72, + 79, + 80, + 94, + 99, + 103.3 + ], + "yaxis": "y13", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + 0.5199275362318841, + null, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + 0.4701086956521739, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + 0.518719806763285, + null, + null, + null, + null, + null + ], + [ + null, + null, 
+ 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + 0.49516908212560384, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 128, + 256, + 32, + 64, + 64, + 512, + 256, + 32, + 512, + 512 + ], + "xaxis": "x13", + "y": [ + 13, + 65, + 79, + 80, + 24, + 62, + 94, + 99, + 46, + 72 + ], + "yaxis": "y13" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 0.0334069404001943, + 0.04046713177503456, + 0.051928275495301705, + 0.0768868475443224, + 0.0805911907341199, + 0.11575633038847206, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.1816709592718398, + 0.18873115064668006 + ], + "xaxis": "x2", + "y": [ + 8, + 32, + 64, + 128, + 256, + 512, + 536 + ], + "yaxis": "y2", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + 0.518719806763285, + 0.49516908212560384, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 
0.5199275362318841, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + 0.4701086956521739, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null, + 0.5, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 0.0805911907341199, + 0.04046713177503456, + 0.051928275495301705, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.0768868475443224, + 0.1816709592718398, + 0.11575633038847206 + ], + "xaxis": "x2", + "y": [ + 128, + 256, + 32, + 64, + 64, + 512, + 256, + 32, + 512, + 512 + ], + "yaxis": "y2" + }, + { + "type": "scatter", + "xaxis": "x6", + "yaxis": "y6" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 0.0334069404001943, + 0.04046713177503456, + 0.051928275495301705, + 0.0768868475443224, + 0.0805911907341199, + 0.11575633038847206, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.1816709592718398, + 0.18873115064668006 + ], + "xaxis": "x10", + "y": [ + 0.6, + 1, + 2, + 4, + 5, + 7, + 9, + 9.4 + ], + "yaxis": "y10", + "z": [ + [ + null, 
+ null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + 0.4701086956521739, + null, + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null + ], + [ + null, + null, + null, + 0.49516908212560384, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + 0.518719806763285, + null, + null, + null, + 0.5, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 0.0805911907341199, + 0.04046713177503456, + 0.051928275495301705, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.0768868475443224, + 0.1816709592718398, + 0.11575633038847206 + ], + "xaxis": "x10", + "y": [ + 1, + 1, + 7, + 7, + 2, + 4, + 9, + 5, + 7, + 9 + ], + "yaxis": "y10" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": 
[ + 0.0334069404001943, + 0.04046713177503456, + 0.051928275495301705, + 0.0768868475443224, + 0.0805911907341199, + 0.11575633038847206, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.1816709592718398, + 0.18873115064668006 + ], + "xaxis": "x14", + "y": [ + 8.7, + 13, + 24, + 46, + 62, + 65, + 72, + 79, + 80, + 94, + 99, + 103.3 + ], + "yaxis": "y14", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null + ], + [ + null, + 0.4701086956521739, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + 0.518719806763285, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + 0.49516908212560384, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 0.0805911907341199, + 0.04046713177503456, + 0.051928275495301705, + 0.13154994845426374, + 0.13397751341486178, + 
0.136820133194068, + 0.16534854040828872, + 0.0768868475443224, + 0.1816709592718398, + 0.11575633038847206 + ], + "xaxis": "x14", + "y": [ + 13, + 65, + 79, + 80, + 24, + 62, + 94, + 99, + 46, + 72 + ], + "yaxis": "y14" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 0.6, + 1, + 2, + 4, + 5, + 7, + 9, + 9.4 + ], + "xaxis": "x3", + "y": [ + 8, + 32, + 64, + 128, + 256, + 512, + 536 + ], + "yaxis": "y3", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 0.49516908212560384, + 0.518719806763285, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + 0.5, + null, + null + ], + [ + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null + ], + [ + null, + 0.4701086956521739, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + 0.5, + null, + 0.5, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 1, + 1, + 7, + 7, + 2, + 4, + 9, + 5, + 7, + 9 + ], + "xaxis": "x3", + "y": [ + 128, + 256, + 32, + 64, + 64, + 512, + 256, + 32, + 512, + 512 + ], + "yaxis": "y3" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + 
"rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 0.6, + 1, + 2, + 4, + 5, + 7, + 9, + 9.4 + ], + "xaxis": "x7", + "y": [ + 0.0334069404001943, + 0.04046713177503456, + 0.051928275495301705, + 0.0768868475443224, + 0.0805911907341199, + 0.11575633038847206, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.1816709592718398, + 0.18873115064668006 + ], + "yaxis": "y7", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + 0.4701086956521739, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.518719806763285, + null, + null + ], + [ + null, + null, + null, + null, + 0.49516908212560384, + null, + null, + null + ], + [ + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 1, + 1, 
+ 7, + 7, + 2, + 4, + 9, + 5, + 7, + 9 + ], + "xaxis": "x7", + "y": [ + 0.0805911907341199, + 0.04046713177503456, + 0.051928275495301705, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.0768868475443224, + 0.1816709592718398, + 0.11575633038847206 + ], + "yaxis": "y7" + }, + { + "type": "scatter", + "xaxis": "x11", + "yaxis": "y11" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 0.6, + 1, + 2, + 4, + 5, + 7, + 9, + 9.4 + ], + "xaxis": "x15", + "y": [ + 8.7, + 13, + 24, + 46, + 62, + 65, + 72, + 79, + 80, + 94, + 99, + 103.3 + ], + "yaxis": "y15", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + 0.5, + null, + null, + null, + null + ], + [ + null, + 0.4701086956521739, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + null, + 0.518719806763285, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null + ], + [ + null, + null, + null, + null, + 
0.49516908212560384, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 1, + 1, + 7, + 7, + 2, + 4, + 9, + 5, + 7, + 9 + ], + "xaxis": "x15", + "y": [ + 13, + 65, + 79, + 80, + 24, + 62, + 94, + 99, + 46, + 72 + ], + "yaxis": "y15" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 8.7, + 13, + 24, + 46, + 62, + 65, + 72, + 79, + 80, + 94, + 99, + 103.3 + ], + "xaxis": "x4", + "y": [ + 8, + 32, + 64, + 128, + 256, + 512, + 536 + ], + "yaxis": "y4", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + 0.518719806763285, + null, + null, + 0.49516908212560384, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null + ], + [ + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.4701086956521739, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + 0.5, + 0.5, + null, + 0.5, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 
null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 13, + 65, + 79, + 80, + 24, + 62, + 94, + 99, + 46, + 72 + ], + "xaxis": "x4", + "y": [ + 128, + 256, + 32, + 64, + 64, + 512, + 256, + 32, + 512, + 512 + ], + "yaxis": "y4" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 8.7, + 13, + 24, + 46, + 62, + 65, + 72, + 79, + 80, + 94, + 99, + 103.3 + ], + "xaxis": "x8", + "y": [ + 0.0334069404001943, + 0.04046713177503456, + 0.051928275495301705, + 0.0768868475443224, + 0.0805911907341199, + 0.11575633038847206, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.1816709592718398, + 0.18873115064668006 + ], + "yaxis": "y8", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + 0.4701086956521739, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + 0.518719806763285, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + 0.49516908212560384, + null + ], + [ + null, + 0.5199275362318841, + null, + null, + null, + null, + null, + null, + null, + null, + 
null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 0.5, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + 0.5, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 13, + 65, + 79, + 80, + 24, + 62, + 94, + 99, + 46, + 72 + ], + "xaxis": "x8", + "y": [ + 0.0805911907341199, + 0.04046713177503456, + 0.051928275495301705, + 0.13154994845426374, + 0.13397751341486178, + 0.136820133194068, + 0.16534854040828872, + 0.0768868475443224, + 0.1816709592718398, + 0.11575633038847206 + ], + "yaxis": "y8" + }, + { + "colorbar": { + "title": { + "text": "Objective Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "connectgaps": true, + "contours": { + "coloring": "heatmap" + }, + "hoverinfo": "none", + "line": { + "smoothing": 1.3 + }, + "reversescale": true, + "showscale": false, + "type": "contour", + "x": [ + 8.7, + 13, + 24, + 46, + 62, + 65, + 72, + 79, + 80, + 94, + 99, + 103.3 + ], + "xaxis": "x12", + "y": [ + 0.6, + 1, 
+ 2, + 4, + 5, + 7, + 9, + 9.4 + ], + "yaxis": "y12", + "z": [ + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + 0.5199275362318841, + null, + null, + null, + 0.4701086956521739, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + 0.5, + null, + null, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + 0.5, + null, + null, + null, + null, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + 0.49516908212560384, + null + ], + [ + null, + null, + null, + 0.5, + null, + null, + null, + 0.518719806763285, + 0.5, + null, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + 0.5, + null, + null, + 0.5, + null, + null + ], + [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ] + ] + }, + { + "marker": { + "color": "black", + "line": { + "color": "Grey", + "width": 2 + } + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 13, + 65, + 79, + 80, + 24, + 62, + 94, + 99, + 46, + 72 + ], + "xaxis": "x12", + "y": [ + 1, + 1, + 7, + 7, + 2, + 4, + 9, + 5, + 7, + 9 + ], + "yaxis": "y12" + }, + { + "type": "scatter", + "xaxis": "x16", + "yaxis": "y16" + } + ], + "layout": { + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": 
"white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + 
"type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], 
+ "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 
0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": 
true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Contour Plot" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 0.2125 + ], + "matches": "x13", + "range": [ + 8, + 536 + ], + "showticklabels": false + }, + "xaxis10": { + "anchor": "y10", + "domain": [ + 0.2625, + 0.475 + ], + "matches": "x14", + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "showticklabels": false + }, + "xaxis11": { + "anchor": "y11", + "domain": [ + 0.525, + 0.7375 + ], + "matches": "x15", + "range": [ + 0.6, + 9.4 + ], + "showticklabels": false + }, + "xaxis12": { + "anchor": "y12", + "domain": [ + 0.7875, + 1 + ], + "matches": "x16", + "range": [ + 8.7, + 103.3 + ], + "showticklabels": false + }, + "xaxis13": { + "anchor": "y13", + "domain": [ + 0, + 0.2125 + ], + "range": [ + 8, + 536 + ], + "title": { + "text": "batch_size" + } + }, + "xaxis14": { + "anchor": "y14", + "domain": [ + 0.2625, + 0.475 + ], + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "title": { + "text": "dropout" + } + }, + 
"xaxis15": { + "anchor": "y15", + "domain": [ + 0.525, + 0.7375 + ], + "range": [ + 0.6, + 9.4 + ], + "title": { + "text": "n_layers_hidden" + } + }, + "xaxis16": { + "anchor": "y16", + "domain": [ + 0.7875, + 1 + ], + "range": [ + 8.7, + 103.3 + ], + "title": { + "text": "n_units_hidden" + } + }, + "xaxis2": { + "anchor": "y2", + "domain": [ + 0.2625, + 0.475 + ], + "matches": "x14", + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "showticklabels": false + }, + "xaxis3": { + "anchor": "y3", + "domain": [ + 0.525, + 0.7375 + ], + "matches": "x15", + "range": [ + 0.6, + 9.4 + ], + "showticklabels": false + }, + "xaxis4": { + "anchor": "y4", + "domain": [ + 0.7875, + 1 + ], + "matches": "x16", + "range": [ + 8.7, + 103.3 + ], + "showticklabels": false + }, + "xaxis5": { + "anchor": "y5", + "domain": [ + 0, + 0.2125 + ], + "matches": "x13", + "range": [ + 8, + 536 + ], + "showticklabels": false + }, + "xaxis6": { + "anchor": "y6", + "domain": [ + 0.2625, + 0.475 + ], + "matches": "x14", + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "showticklabels": false + }, + "xaxis7": { + "anchor": "y7", + "domain": [ + 0.525, + 0.7375 + ], + "matches": "x15", + "range": [ + 0.6, + 9.4 + ], + "showticklabels": false + }, + "xaxis8": { + "anchor": "y8", + "domain": [ + 0.7875, + 1 + ], + "matches": "x16", + "range": [ + 8.7, + 103.3 + ], + "showticklabels": false + }, + "xaxis9": { + "anchor": "y9", + "domain": [ + 0, + 0.2125 + ], + "matches": "x13", + "range": [ + 8, + 536 + ], + "showticklabels": false + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0.80625, + 1 + ], + "range": [ + 8, + 536 + ], + "title": { + "text": "batch_size" + } + }, + "yaxis10": { + "anchor": "x10", + "domain": [ + 0.26875, + 0.4625 + ], + "matches": "y9", + "range": [ + 0.6, + 9.4 + ], + "showticklabels": false + }, + "yaxis11": { + "anchor": "x11", + "domain": [ + 0.26875, + 0.4625 + ], + "matches": "y9", + "range": [ + 0.6, + 9.4 + ], + "showticklabels": false + }, 
+ "yaxis12": { + "anchor": "x12", + "domain": [ + 0.26875, + 0.4625 + ], + "matches": "y9", + "range": [ + 0.6, + 9.4 + ], + "showticklabels": false + }, + "yaxis13": { + "anchor": "x13", + "domain": [ + 0, + 0.19375 + ], + "range": [ + 8.7, + 103.3 + ], + "title": { + "text": "n_units_hidden" + } + }, + "yaxis14": { + "anchor": "x14", + "domain": [ + 0, + 0.19375 + ], + "matches": "y13", + "range": [ + 8.7, + 103.3 + ], + "showticklabels": false + }, + "yaxis15": { + "anchor": "x15", + "domain": [ + 0, + 0.19375 + ], + "matches": "y13", + "range": [ + 8.7, + 103.3 + ], + "showticklabels": false + }, + "yaxis16": { + "anchor": "x16", + "domain": [ + 0, + 0.19375 + ], + "matches": "y13", + "range": [ + 8.7, + 103.3 + ], + "showticklabels": false + }, + "yaxis2": { + "anchor": "x2", + "domain": [ + 0.80625, + 1 + ], + "matches": "y", + "range": [ + 8, + 536 + ], + "showticklabels": false + }, + "yaxis3": { + "anchor": "x3", + "domain": [ + 0.80625, + 1 + ], + "matches": "y", + "range": [ + 8, + 536 + ], + "showticklabels": false + }, + "yaxis4": { + "anchor": "x4", + "domain": [ + 0.80625, + 1 + ], + "matches": "y", + "range": [ + 8, + 536 + ], + "showticklabels": false + }, + "yaxis5": { + "anchor": "x5", + "domain": [ + 0.5375, + 0.73125 + ], + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "title": { + "text": "dropout" + } + }, + "yaxis6": { + "anchor": "x6", + "domain": [ + 0.5375, + 0.73125 + ], + "matches": "y5", + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "showticklabels": false + }, + "yaxis7": { + "anchor": "x7", + "domain": [ + 0.5375, + 0.73125 + ], + "matches": "y5", + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "showticklabels": false + }, + "yaxis8": { + "anchor": "x8", + "domain": [ + 0.5375, + 0.73125 + ], + "matches": "y5", + "range": [ + 0.0334069404001943, + 0.18873115064668006 + ], + "showticklabels": false + }, + "yaxis9": { + "anchor": "x9", + "domain": [ + 0.26875, + 0.4625 + ], + "range": [ 
+ 0.6, + 9.4 + ], + "title": { + "text": "n_layers_hidden" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize hyperparameter relationships.\n", + "plot_contour(study, params=['batch_size', 'dropout', 'n_layers_hidden', 'n_units_hidden'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize individual hyperparameters as slice plot.\n", + "plot_slice(study)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "cliponaxis": false, + "hovertemplate": [ + "lr (CategoricalDistribution): 0.00846909212308733", + "linear_transform_type (CategoricalDistribution): 0.017513441865308507", + "batch_norm (CategoricalDistribution): 0.03432980997949527", + "base_transform_type (CategoricalDistribution): 0.03608782621904944", + "n_units_hidden (IntUniformDistribution): 0.056248651052508904", + "batch_size (CategoricalDistribution): 0.06533755831900438", + "n_layers_hidden (IntUniformDistribution): 0.07065620910903445", + "n_iter (IntUniformDistribution): 0.253054441613576", + "dropout (UniformDistribution): 0.45830296971893575" + ], + "marker": { + "color": "rgb(66,146,198)" + }, + "orientation": "h", + "text": [ + "<0.01", + "0.02", + "0.03", + "0.04", + "0.06", + "0.07", + "0.07", + "0.25", + "0.46" + ], + "textposition": "outside", + "type": "bar", + "x": [ + 0.00846909212308733, + 0.017513441865308507, + 0.03432980997949527, + 0.03608782621904944, + 0.056248651052508904, + 0.06533755831900438, + 0.07065620910903445, + 0.253054441613576, + 0.45830296971893575 + ], + "y": [ + "lr", + "linear_transform_type", + "batch_norm", + "base_transform_type", + "n_units_hidden", + "batch_size", + "n_layers_hidden", + "n_iter", + "dropout" + ] + } + ], + "layout": { + "showlegend": false, + 
"template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + 
"parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": 
{ + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": 
"closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Hyperparameter Importances" + }, + "xaxis": { + "title": { + "text": "Importance for Objective Value" + } + }, + "yaxis": { + "title": { + "text": "Hyperparameter" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize parameter importances.\n", + 
"plot_param_importances(study)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "cliponaxis": false, + "hovertemplate": [ + "batch_size (CategoricalDistribution): 0.0031661764258092114", + "batch_norm (CategoricalDistribution): 0.005887013322608155", + "linear_transform_type (CategoricalDistribution): 0.007340845659453336", + "n_layers_hidden (IntUniformDistribution): 0.018588433379606542", + "lr (CategoricalDistribution): 0.03405957797841154", + "n_iter (IntUniformDistribution): 0.049231044021663506", + "dropout (UniformDistribution): 0.055752145582047774", + "base_transform_type (CategoricalDistribution): 0.1964047830096296", + "n_units_hidden (IntUniformDistribution): 0.6295699806207704" + ], + "marker": { + "color": "rgb(66,146,198)" + }, + "orientation": "h", + "text": [ + "<0.01", + "<0.01", + "<0.01", + "0.02", + "0.03", + "0.05", + "0.06", + "0.20", + "0.63" + ], + "textposition": "outside", + "type": "bar", + "x": [ + 0.0031661764258092114, + 0.005887013322608155, + 0.007340845659453336, + 0.018588433379606542, + 0.03405957797841154, + 0.049231044021663506, + 0.055752145582047774, + 0.1964047830096296, + 0.6295699806207704 + ], + "y": [ + "batch_size", + "batch_norm", + "linear_transform_type", + "n_layers_hidden", + "lr", + "n_iter", + "dropout", + "base_transform_type", + "n_units_hidden" + ] + } + ], + "layout": { + "showlegend": false, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + 
"solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + 
"#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": 
"" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], 
+ [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + 
"ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Hyperparameter Importances" + }, + "xaxis": { + "title": { + "text": "Importance for duration" + } + }, + "yaxis": { + "title": { + "text": "Hyperparameter" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Learn which hyperparameters are affecting the trial duration with hyperparameter importance.\n", + "optuna.visualization.plot_param_importances(\n", + " study, target=lambda t: t.duration.total_seconds(), target_name=\"duration\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "mode": "lines", + "name": "no-name-ce6b3d6c-0504-44aa-a791-928abb036a1c", + 
"type": "scatter", + "x": [ + 0.4701086956521739, + 0.47061191626409016, + 0.4711151368760064, + 0.4716183574879227, + 0.47212157809983896, + 0.4726247987117552, + 0.4731280193236715, + 0.47363123993558776, + 0.474134460547504, + 0.4746376811594203, + 0.47514090177133655, + 0.4756441223832528, + 0.4761473429951691, + 0.47665056360708535, + 0.4771537842190016, + 0.4776570048309179, + 0.47816022544283415, + 0.4786634460547504, + 0.4791666666666667, + 0.47966988727858295, + 0.4801731078904992, + 0.4806763285024155, + 0.48117954911433175, + 0.481682769726248, + 0.4821859903381643, + 0.48268921095008055, + 0.48319243156199676, + 0.483695652173913, + 0.4841988727858293, + 0.48470209339774556, + 0.4852053140096618, + 0.4857085346215781, + 0.48621175523349436, + 0.4867149758454106, + 0.4872181964573269, + 0.48772141706924316, + 0.4882246376811594, + 0.4887278582930757, + 0.48923107890499196, + 0.4897342995169082, + 0.4902375201288245, + 0.49074074074074076, + 0.491243961352657, + 0.4917471819645733, + 0.49225040257648955, + 0.4927536231884058, + 0.4932568438003221, + 0.49376006441223835, + 0.4942632850241546, + 0.4947665056360709, + 0.49526972624798715, + 0.4957729468599034, + 0.4962761674718197, + 0.49677938808373595, + 0.4972826086956522, + 0.4977858293075685, + 0.49828904991948475, + 0.498792270531401, + 0.4992954911433173, + 0.49979871175523355, + 0.5003019323671498, + 0.500805152979066, + 0.5013083735909823, + 0.5018115942028986, + 0.5023148148148149, + 0.5028180354267311, + 0.5033212560386474, + 0.5038244766505636, + 0.50432769726248, + 0.5048309178743962, + 0.5053341384863125, + 0.5058373590982287, + 0.506340579710145, + 0.5068438003220612, + 0.5073470209339775, + 0.5078502415458938, + 0.5083534621578101, + 0.5088566827697263, + 0.5093599033816426, + 0.5098631239935588, + 0.5103663446054751, + 0.5108695652173914, + 0.5113727858293077, + 0.5118760064412239, + 0.5123792270531402, + 0.5128824476650564, + 0.5133856682769727, + 0.513888888888889, + 0.5143921095008053, + 
0.5148953301127215, + 0.5153985507246378, + 0.515901771336554, + 0.5164049919484703, + 0.5169082125603865, + 0.5174114331723029, + 0.5179146537842191, + 0.5184178743961354, + 0.5189210950080516, + 0.5194243156199679, + 0.5199275362318841 + ], + "y": [ + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.8, + 0.9, + 0.9, + 1 + ] + } + ], + "layout": { + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + 
"ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 
0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + 
"scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, 
+ "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" 
+ }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Empirical Distribution Function Plot" + }, + "xaxis": { + "title": { + "text": "Objective Value" + } + }, + "yaxis": { + "range": [ + 0, + 1 + ], + "title": { + "text": "Cumulative Probability" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize empirical distribution function of the objective.\n", + "plot_edf(study)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test performance of the optimized plugin" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-04-07T22:13:18.269044+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:03<00:00, 30.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[4m\u001b[1mPlugin : test\u001b[0m\u001b[0m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
minmaxmeanstddevmedianiqrroundserrorsdurations
detection.detection_xgb.mean1.0000001.0000001.0000000.01.0000000.0100.14
detection.detection_mlp.mean0.5854410.5854410.5854410.00.5854410.0101.82
\n", + "
" + ], + "text/plain": [ + " min max mean stddev median \\\n", + "detection.detection_xgb.mean 1.000000 1.000000 1.000000 0.0 1.000000 \n", + "detection.detection_mlp.mean 0.585441 0.585441 0.585441 0.0 0.585441 \n", + "\n", + " iqr rounds errors durations \n", + "detection.detection_xgb.mean 0.0 1 0 0.14 \n", + "detection.detection_mlp.mean 0.0 1 0 1.82 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "best_params = study.best_params\n", + "best_params['n_iter'] = 100\n", + "report = Benchmarks.evaluate(\n", + " [(\"test\", PLUGIN, best_params)],\n", + " train_loader,\n", + " test_loader,\n", + " repeats=1,\n", + " metrics={\"detection\": [\"detection_mlp\", \"detection_xgb\"]}, # DELETE THIS LINE FOR ALL METRICS\n", + ")\n", + "Benchmarks.print(report)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Congratulations!\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!\n", + "\n", + "### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub\n", + "\n", + "- The easiest way to help our community is just by starring the Repos! 
This helps raise awareness of the tools we're building.\n", + "\n", + "\n", + "### Checkout other projects from vanderschaarlab\n", + "- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)\n", + "- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "synthcity", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 2fb8508939d583acce06a05a77a599773d469842 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 8 Apr 2023 19:14:54 +0200 Subject: [PATCH 59/95] update --- tests/plugins/generic/test_ddpm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/plugins/generic/test_ddpm.py b/tests/plugins/generic/test_ddpm.py index ae0462d8..0ba32d81 100644 --- a/tests/plugins/generic/test_ddpm.py +++ b/tests/plugins/generic/test_ddpm.py @@ -18,11 +18,10 @@ plugin_name = "ddpm" plugin_params = dict( - n_iter=500, + n_iter=1000, batch_size=1000, num_timesteps=100, model_type="mlp", - sampling_patience=100, ) @@ -122,7 +121,7 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None: def test_sample_hyperparams() -> None: - for i in range(100): + for _ in range(100): args = plugin.sample_hyperparameters() assert plugin(**args) is not None From 5adfabfd435c33cb004ed9bd0ab5a56d070f895d Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 8 Apr 2023 19:16:13 +0200 Subject: [PATCH 60/95] Fix plugin type of static_model of fflows --- src/synthcity/plugins/time_series/plugin_fflows.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/synthcity/plugins/time_series/plugin_fflows.py 
b/src/synthcity/plugins/time_series/plugin_fflows.py index d357ea24..d62557c0 100644 --- a/src/synthcity/plugins/time_series/plugin_fflows.py +++ b/src/synthcity/plugins/time_series/plugin_fflows.py @@ -11,6 +11,7 @@ from fflows import FourierFlow # synthcity absolute +from synthcity.plugins import Plugins from synthcity.plugins.core.dataloader import DataLoader from synthcity.plugins.core.distribution import ( CategoricalDistribution, @@ -24,7 +25,6 @@ from synthcity.plugins.core.models.ts_model import TimeSeriesModel from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema -from synthcity.plugins.generic import GenericPlugins from synthcity.utils.constants import DEVICE @@ -134,9 +134,7 @@ def __init__( normalize=normalize, ).to(device) - self.static_model = GenericPlugins().get( - self.static_model_name, device=self.device - ) + self.static_model = Plugins().get(self.static_model_name, device=self.device) self.temporal_encoder = TimeSeriesTabularEncoder( max_clusters=encoder_max_clusters From a2a88c51fd0bf33d70d5e5f3ce7cfeb75a73ead3 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 8 Apr 2023 21:39:13 +0200 Subject: [PATCH 61/95] update intlogdistribution and tutorial --- src/synthcity/plugins/core/distribution.py | 12 +- src/synthcity/utils/optuna_sample.py | 6 - ...utorial8_hyperparameter_optimization.ipynb | 4316 +++++++++++------ 3 files changed, 2828 insertions(+), 1506 deletions(-) diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index 06b3a99b..9934539c 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -353,25 +353,23 @@ def dtype(self) -> str: class IntLogDistribution(IntegerDistribution): low: int = 1 high: int = np.iinfo(np.int64).max - step: int = 2 # the next sample larger than x is step * x @validator("step", always=True) def _validate_step(cls: Any, v: int, values: Dict) -> int: - 
if v < 2: - raise ValueError("Step must be greater than 1") + if v != 1: + raise ValueError("Step must be 1 for IntLogDistribution") return v def get(self) -> List[Any]: - return [self.name, self.low, self.high, self.step] + return [self.name, self.low, self.high] def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples - steps = int(np.log2(self.high / self.low) / np.log2(self.step)) - samples = np.random.choice(steps + 1, count) - samples = self.low * self.step**samples + lo, hi = np.log2(self.low), np.log2(self.high) + samples = 2.0 ** np.random.uniform(lo, hi, count) return samples.astype(int) diff --git a/src/synthcity/utils/optuna_sample.py b/src/synthcity/utils/optuna_sample.py index c19dae66..87b7aafc 100644 --- a/src/synthcity/utils/optuna_sample.py +++ b/src/synthcity/utils/optuna_sample.py @@ -16,15 +16,9 @@ def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any: elif isinstance(dist, D.IntegerDistribution): return trial.suggest_int(dist.name, dist.low, dist.high, dist.step) elif isinstance(dist, D.IntLogDistribution): - # ! does not handle step yet return trial.suggest_int(dist.name, dist.low, dist.high, log=True) elif isinstance(dist, D.CategoricalDistribution): return trial.suggest_categorical(dist.name, dist.choices) - # ! 
the modification cannot be reflected in study.best_params - # elif isinstance(dist, D.DatetimeDistribution): - # high = (dist.high - dist.low) / dist.step - # s = trial.suggest_float(dist.name, 0, high) - # return dist.low + dist.step * s else: raise ValueError(f"Unknown dist: {dist}") diff --git a/tutorials/tutorial8_hyperparameter_optimization.ipynb b/tutorials/tutorial8_hyperparameter_optimization.ipynb index 4cf7c965..f95c1e0c 100644 --- a/tutorials/tutorial8_hyperparameter_optimization.ipynb +++ b/tutorials/tutorial8_hyperparameter_optimization.ipynb @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -320,14 +320,15 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[2023-04-07T21:51:56.689921+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n" + "[2023-04-08T20:56:15.722354+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-08T20:56:15.722354+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n" ] }, { @@ -336,7 +337,7 @@ "synthcity.plugins.generic.plugin_nflow.NormalizingFlowsPlugin" ] }, - "execution_count": 50, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -357,7 +358,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -368,13 +369,13 @@ " IntegerDistribution(name='n_units_hidden', data=None, random_state=0, marginal_distribution=None, low=10, high=100, step=1),\n", " CategoricalDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, choices=[32, 64, 128, 256, 512]),\n", " FloatDistribution(name='dropout', 
data=None, random_state=0, marginal_distribution=None, low=0.0, high=0.2),\n", - " CategoricalDistribution(name='batch_norm', data=None, random_state=0, marginal_distribution=None, choices=[True, False]),\n", - " CategoricalDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, choices=[0.001, 0.0001, 0.0002]),\n", + " CategoricalDistribution(name='batch_norm', data=None, random_state=0, marginal_distribution=None, choices=[False, True]),\n", + " CategoricalDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, choices=[0.0001, 0.0002, 0.001]),\n", " CategoricalDistribution(name='linear_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['lu', 'permutation', 'svd']),\n", - " CategoricalDistribution(name='base_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['affine-coupling', 'quadratic-coupling', 'rq-coupling', 'affine-autoregressive', 'quadratic-autoregressive', 'rq-autoregressive'])]" + " CategoricalDistribution(name='base_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['affine-autoregressive', 'affine-coupling', 'quadratic-autoregressive', 'quadratic-coupling', 'rq-autoregressive', 'rq-coupling'])]" ] }, - "execution_count": 51, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -393,24 +394,24 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'n_iter': 486,\n", - " 'n_layers_hidden': 10,\n", + "{'n_iter': 100,\n", + " 'n_layers_hidden': 1,\n", " 'n_units_hidden': 87,\n", - " 'batch_size': 512,\n", - " 'dropout': 0.016022465975681178,\n", - " 'batch_norm': True,\n", + " 'batch_size': 256,\n", + " 'dropout': 0.15424246144819787,\n", + " 'batch_norm': False,\n", " 'lr': 0.001,\n", " 'linear_transform_type': 'svd',\n", - " 'base_transform_type': 'affine-coupling'}" + " 'base_transform_type': 
'rq-autoregressive'}" ] }, - "execution_count": 52, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -420,6 +421,7 @@ "\n", "trial = optuna.create_study().ask()\n", "params = suggest_all(trial, plugin_cls.hyperparameter_space())\n", + "params['n_iter'] = 100 # speed up\n", "params" ] }, @@ -433,16 +435,16 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 62%|██████▏ | 299/486 [01:26<00:53, 3.47it/s]\n", - "[2023-04-07T21:53:29.785866+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - " 62%|██████▏ | 299/486 [01:30<00:56, 3.31it/s]\n" + "100%|██████████| 100/100 [00:38<00:00, 2.56it/s]\n", + "[2023-04-08T20:57:54.561757+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:30<00:00, 3.24it/s]\n" ] }, { @@ -481,15 +483,15 @@ " \n", " \n", " detection.detection_mlp.mean\n", - " 0.390071\n", - " 0.390071\n", - " 0.390071\n", + " 0.5\n", + " 0.5\n", + " 0.5\n", " 0.0\n", - " 0.390071\n", + " 0.5\n", " 0.0\n", " 1\n", " 0\n", - " 2.51\n", + " 5.95\n", " minimize\n", " \n", " \n", @@ -497,14 +499,14 @@ "" ], "text/plain": [ - " min max mean stddev median \\\n", - "detection.detection_mlp.mean 0.390071 0.390071 0.390071 0.0 0.390071 \n", + " min max mean stddev median iqr rounds \\\n", + "detection.detection_mlp.mean 0.5 0.5 0.5 0.0 0.5 0.0 1 \n", "\n", - " iqr rounds errors durations direction \n", - "detection.detection_mlp.mean 0.0 1 0 2.51 minimize " + " errors durations direction \n", + "detection.detection_mlp.mean 0 5.95 minimize " ] }, - "execution_count": 53, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -532,60 +534,102 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 14, "metadata": {}, 
"outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[2023-04-07T21:57:37.669090+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T21:57:37.689827+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:12<00:00, 8.05it/s]\n", - "[2023-04-07T21:57:53.690237+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T21:57:53.712601+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:04<00:00, 24.77it/s]\n", - "[2023-04-07T21:58:01.728358+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T21:58:01.744010+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:24<00:00, 4.08it/s]\n", - "[2023-04-07T21:58:32.292499+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T21:58:32.316002+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [01:30<00:00, 1.10it/s]\n", - "[2023-04-07T22:00:38.652411+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T22:00:38.685914+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:23<00:00, 4.21it/s]\n", - "[2023-04-07T22:01:09.148491+0200][4048][CRITICAL] module disabled: 
D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T22:01:09.178259+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:06<00:00, 14.79it/s]\n", - "[2023-04-07T22:01:20.722191+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T22:01:20.751419+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [02:00<00:00, 1.21s/it]\n", - "[2023-04-07T22:03:29.180475+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T22:03:29.211421+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [01:37<00:00, 1.02it/s]\n", - "[2023-04-07T22:05:12.012437+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T22:05:12.030781+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:57<00:00, 1.74it/s]\n", - "[2023-04-07T22:06:14.408112+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-07T22:06:14.431469+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:11<00:00, 8.47it/s]\n" + "[2023-04-08T21:26:16.278633+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-08T21:26:16.301778+0200][28420][CRITICAL] module disabled: 
D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:39<00:00, 2.56it/s]\n", + "[2023-04-08T21:26:59.665597+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-08T21:26:59.684496+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:26<00:00, 3.74it/s]\n", + "[2023-04-08T21:27:30.475951+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-08T21:27:30.495645+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:04<00:00, 21.72it/s]\n", + "[2023-04-08T21:27:39.102805+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-08T21:27:39.117858+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [03:07<00:00, 1.88s/it]\n", + "[2023-04-08T21:31:35.546758+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "[2023-04-08T21:31:35.638203+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + " 0%| | 0/100 [00:00", - "linear_transform_type (CategoricalDistribution): 0.017513441865308507", - "batch_norm (CategoricalDistribution): 0.03432980997949527", - "base_transform_type (CategoricalDistribution): 0.03608782621904944", - "n_units_hidden (IntUniformDistribution): 0.056248651052508904", - "batch_size (CategoricalDistribution): 0.06533755831900438", - "n_layers_hidden (IntUniformDistribution): 0.07065620910903445", 
- "n_iter (IntUniformDistribution): 0.253054441613576", - "dropout (UniformDistribution): 0.45830296971893575" + "lr (CategoricalDistribution): 0.0", + "n_iter (IntDistribution): 0.0", + "n_layers_hidden (IntDistribution): 7.256702823093135e-31", + "batch_norm (CategoricalDistribution): 0.02499999999999996", + "n_units_hidden (IntDistribution): 0.025000000000000088", + "batch_size (CategoricalDistribution): 0.04999999999999992", + "linear_transform_type (CategoricalDistribution): 0.09999999999999984", + "dropout (FloatDistribution): 0.275000000000001", + "base_transform_type (CategoricalDistribution): 0.5249999999999992" ], "marker": { "color": "rgb(66,146,198)" }, "orientation": "h", "text": [ + "<0.01", + "<0.01", "<0.01", "0.02", "0.03", - "0.04", - "0.06", - "0.07", - "0.07", - "0.25", - "0.46" + "0.05", + "0.10", + "0.28", + "0.52" ], "textposition": "outside", "type": "bar", "x": [ - 0.00846909212308733, - 0.017513441865308507, - 0.03432980997949527, - 0.03608782621904944, - 0.056248651052508904, - 0.06533755831900438, - 0.07065620910903445, - 0.253054441613576, - 0.45830296971893575 + 0, + 0, + 7.256702823093135e-31, + 0.02499999999999996, + 0.025000000000000088, + 0.04999999999999992, + 0.09999999999999984, + 0.275000000000001, + 0.5249999999999992 ], "y": [ "lr", - "linear_transform_type", + "n_iter", + "n_layers_hidden", "batch_norm", - "base_transform_type", "n_units_hidden", "batch_size", - "n_layers_hidden", - "n_iter", - "dropout" + "linear_transform_type", + "dropout", + "base_transform_type" ] } ], @@ -7716,7 +9047,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -7729,15 +9060,15 @@ { "cliponaxis": false, "hovertemplate": [ - "batch_size (CategoricalDistribution): 0.0031661764258092114", - "batch_norm (CategoricalDistribution): 0.005887013322608155", - "linear_transform_type (CategoricalDistribution): 0.007340845659453336", - "n_layers_hidden (IntUniformDistribution): 
0.018588433379606542", - "lr (CategoricalDistribution): 0.03405957797841154", - "n_iter (IntUniformDistribution): 0.049231044021663506", - "dropout (UniformDistribution): 0.055752145582047774", - "base_transform_type (CategoricalDistribution): 0.1964047830096296", - "n_units_hidden (IntUniformDistribution): 0.6295699806207704" + "n_iter (IntDistribution): 0.0", + "batch_norm (CategoricalDistribution): 0.0008792721126079252", + "n_units_hidden (IntDistribution): 0.050652923907431195", + "lr (CategoricalDistribution): 0.07515326418808736", + "linear_transform_type (CategoricalDistribution): 0.08234393383908772", + "batch_size (CategoricalDistribution): 0.1506171783782107", + "dropout (FloatDistribution): 0.1928283779305551", + "n_layers_hidden (IntDistribution): 0.20147707299584372", + "base_transform_type (CategoricalDistribution): 0.2460479766481761" ], "marker": { "color": "rgb(66,146,198)" @@ -7746,37 +9077,37 @@ "text": [ "<0.01", "<0.01", - "<0.01", - "0.02", - "0.03", "0.05", - "0.06", + "0.08", + "0.08", + "0.15", + "0.19", "0.20", - "0.63" + "0.25" ], "textposition": "outside", "type": "bar", "x": [ - 0.0031661764258092114, - 0.005887013322608155, - 0.007340845659453336, - 0.018588433379606542, - 0.03405957797841154, - 0.049231044021663506, - 0.055752145582047774, - 0.1964047830096296, - 0.6295699806207704 + 0, + 0.0008792721126079252, + 0.050652923907431195, + 0.07515326418808736, + 0.08234393383908772, + 0.1506171783782107, + 0.1928283779305551, + 0.20147707299584372, + 0.2460479766481761 ], "y": [ - "batch_size", + "n_iter", "batch_norm", - "linear_transform_type", - "n_layers_hidden", + "n_units_hidden", "lr", - "n_iter", + "linear_transform_type", + "batch_size", "dropout", - "base_transform_type", - "n_units_hidden" + "n_layers_hidden", + "base_transform_type" ] } ], @@ -8627,7 +9958,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -8639,210 +9970,210 @@ "data": [ { "mode": "lines", - 
"name": "no-name-ce6b3d6c-0504-44aa-a791-928abb036a1c", + "name": "no-name-13d7c589-c089-4e41-befb-9c6abf683ae9", "type": "scatter", "x": [ - 0.4701086956521739, - 0.47061191626409016, - 0.4711151368760064, - 0.4716183574879227, - 0.47212157809983896, - 0.4726247987117552, - 0.4731280193236715, - 0.47363123993558776, - 0.474134460547504, - 0.4746376811594203, - 0.47514090177133655, - 0.4756441223832528, - 0.4761473429951691, - 0.47665056360708535, - 0.4771537842190016, - 0.4776570048309179, - 0.47816022544283415, - 0.4786634460547504, - 0.4791666666666667, - 0.47966988727858295, - 0.4801731078904992, - 0.4806763285024155, - 0.48117954911433175, - 0.481682769726248, - 0.4821859903381643, - 0.48268921095008055, - 0.48319243156199676, - 0.483695652173913, - 0.4841988727858293, - 0.48470209339774556, - 0.4852053140096618, - 0.4857085346215781, - 0.48621175523349436, - 0.4867149758454106, - 0.4872181964573269, - 0.48772141706924316, - 0.4882246376811594, - 0.4887278582930757, - 0.48923107890499196, - 0.4897342995169082, - 0.4902375201288245, - 0.49074074074074076, - 0.491243961352657, - 0.4917471819645733, - 0.49225040257648955, - 0.4927536231884058, - 0.4932568438003221, - 0.49376006441223835, - 0.4942632850241546, - 0.4947665056360709, - 0.49526972624798715, - 0.4957729468599034, - 0.4962761674718197, - 0.49677938808373595, - 0.4972826086956522, - 0.4977858293075685, - 0.49828904991948475, - 0.498792270531401, - 0.4992954911433173, - 0.49979871175523355, - 0.5003019323671498, - 0.500805152979066, - 0.5013083735909823, - 0.5018115942028986, - 0.5023148148148149, - 0.5028180354267311, - 0.5033212560386474, - 0.5038244766505636, - 0.50432769726248, - 0.5048309178743962, - 0.5053341384863125, - 0.5058373590982287, - 0.506340579710145, - 0.5068438003220612, - 0.5073470209339775, - 0.5078502415458938, - 0.5083534621578101, - 0.5088566827697263, - 0.5093599033816426, - 0.5098631239935588, - 0.5103663446054751, - 0.5108695652173914, - 0.5113727858293077, - 0.5118760064412239, 
- 0.5123792270531402, - 0.5128824476650564, - 0.5133856682769727, - 0.513888888888889, - 0.5143921095008053, - 0.5148953301127215, - 0.5153985507246378, - 0.515901771336554, - 0.5164049919484703, - 0.5169082125603865, - 0.5174114331723029, - 0.5179146537842191, - 0.5184178743961354, - 0.5189210950080516, - 0.5194243156199679, - 0.5199275362318841 + 0.4788647342995169, + 0.47907822183184506, + 0.47929170936417315, + 0.4795051968965013, + 0.47971868442882937, + 0.4799321719611575, + 0.4801456594934856, + 0.4803591470258137, + 0.4805726345581418, + 0.48078612209046995, + 0.48099960962279803, + 0.48121309715512617, + 0.48142658468745425, + 0.4816400722197824, + 0.48185355975211047, + 0.4820670472844386, + 0.4822805348167667, + 0.48249402234909483, + 0.4827075098814229, + 0.48292099741375105, + 0.48313448494607913, + 0.4833479724784073, + 0.4835614600107354, + 0.4837749475430635, + 0.48398843507539163, + 0.4842019226077197, + 0.48441541014004785, + 0.48462889767237594, + 0.4848423852047041, + 0.48505587273703216, + 0.4852693602693603, + 0.4854828478016884, + 0.4856963353340165, + 0.4859098228663446, + 0.48612331039867274, + 0.4863367979310008, + 0.48655028546332896, + 0.48676377299565704, + 0.4869772605279852, + 0.48719074806031326, + 0.4874042355926414, + 0.48761772312496954, + 0.4878312106572976, + 0.48804469818962576, + 0.48825818572195384, + 0.488471673254282, + 0.48868516078661006, + 0.4888986483189382, + 0.4891121358512663, + 0.4893256233835944, + 0.4895391109159225, + 0.48975259844825064, + 0.4899660859805787, + 0.49017957351290686, + 0.49039306104523495, + 0.4906065485775631, + 0.49082003610989117, + 0.4910335236422193, + 0.4912470111745474, + 0.4914604987068755, + 0.49167398623920366, + 0.49188747377153175, + 0.4921009613038599, + 0.49231444883618797, + 0.4925279363685161, + 0.4927414239008442, + 0.4929549114331723, + 0.4931683989655004, + 0.49338188649782855, + 0.49359537403015663, + 0.49380886156248477, + 0.49402234909481285, + 0.494235836627141, + 
0.4944493241594691, + 0.4946628116917972, + 0.4948762992241253, + 0.49508978675645343, + 0.4953032742887815, + 0.49551676182110965, + 0.4957302493534378, + 0.4959437368857659, + 0.496157224418094, + 0.4963707119504221, + 0.49658419948275023, + 0.4967976870150783, + 0.49701117454740645, + 0.49722466207973454, + 0.4974381496120627, + 0.49765163714439076, + 0.4978651246767189, + 0.498078612209047, + 0.4982920997413751, + 0.4985055872737032, + 0.49871907480603134, + 0.4989325623383594, + 0.49914604987068756, + 0.49935953740301564, + 0.4995730249353438, + 0.4997865124676719, + 0.5 ], "y": [ - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.2, - 0.2, - 0.2, - 0.2, - 0.2, - 0.2, - 0.2, - 0.2, - 0.2, - 0.2, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.8, - 0.9, - 0.9, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, 
+ 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, + 0.14285714285714285, 1 ] } @@ -9703,15 +11034,15 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[2023-04-07T22:13:18.269044+0200][4048][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:03<00:00, 30.59it/s]\n" + 
"[2023-04-08T21:36:39.947037+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", + "100%|██████████| 100/100 [00:20<00:00, 4.87it/s]\n" ] }, { @@ -9757,27 +11088,27 @@ " \n", " \n", " detection.detection_xgb.mean\n", - " 1.000000\n", - " 1.000000\n", - " 1.000000\n", + " 0.988506\n", + " 0.988506\n", + " 0.988506\n", " 0.0\n", - " 1.000000\n", + " 0.988506\n", " 0.0\n", " 1\n", " 0\n", - " 0.14\n", + " 0.18\n", " \n", " \n", " detection.detection_mlp.mean\n", - " 0.585441\n", - " 0.585441\n", - " 0.585441\n", + " 0.703640\n", + " 0.703640\n", + " 0.703640\n", " 0.0\n", - " 0.585441\n", + " 0.703640\n", " 0.0\n", " 1\n", " 0\n", - " 1.82\n", + " 3.22\n", " \n", " \n", "\n", @@ -9785,12 +11116,12 @@ ], "text/plain": [ " min max mean stddev median \\\n", - "detection.detection_xgb.mean 1.000000 1.000000 1.000000 0.0 1.000000 \n", - "detection.detection_mlp.mean 0.585441 0.585441 0.585441 0.0 0.585441 \n", + "detection.detection_xgb.mean 0.988506 0.988506 0.988506 0.0 0.988506 \n", + "detection.detection_mlp.mean 0.703640 0.703640 0.703640 0.0 0.703640 \n", "\n", " iqr rounds errors durations \n", - "detection.detection_xgb.mean 0.0 1 0 0.14 \n", - "detection.detection_mlp.mean 0.0 1 0 1.82 " + "detection.detection_xgb.mean 0.0 1 0 0.18 \n", + "detection.detection_mlp.mean 0.0 1 0 3.22 " ] }, "metadata": {}, @@ -9806,7 +11137,6 @@ ], "source": [ "best_params = study.best_params\n", - "best_params['n_iter'] = 100\n", "report = Benchmarks.evaluate(\n", " [(\"test\", PLUGIN, best_params)],\n", " train_loader,\n", From 4a7e73bfca1c99db5ea0f1432e47e6d516556453 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 8 Apr 2023 22:49:02 +0200 Subject: [PATCH 62/95] try fixing goggle --- src/synthcity/plugins/core/models/convnet.py | 31 +++++--------------- src/synthcity/plugins/core/models/goggle.py | 7 +++-- tests/plugins/core/models/test_convnet.py | 4 +-- 3 files changed, 14 
insertions(+), 28 deletions(-) diff --git a/src/synthcity/plugins/core/models/convnet.py b/src/synthcity/plugins/core/models/convnet.py index ae4260e6..a80bbb12 100644 --- a/src/synthcity/plugins/core/models/convnet.py +++ b/src/synthcity/plugins/core/models/convnet.py @@ -4,36 +4,19 @@ # third party import numpy as np import torch -from monai.networks.layers.factories import Act + +# from monai.networks.layers.factories import Act from monai.networks.nets import Classifier, Discriminator, Generator from pydantic import validate_arguments from torch import nn # synthcity absolute import synthcity.logger as log +from synthcity.plugins.core.models.factory import get_nonlin from synthcity.utils.constants import DEVICE from synthcity.utils.reproducibility import enable_reproducible_results -def map_nonlin(nonlin: str) -> Act: - if nonlin == "relu": - return Act.RELU - elif nonlin == "elu": - return Act.ELU - elif nonlin == "prelu": - return Act.PRELU - elif nonlin == "leaky_relu": - return Act.LEAKYRELU - elif nonlin == "sigmoid": - return Act.SIGMOID - elif nonlin == "softmax": - return Act.SOFTMAX - elif nonlin == "tanh": - return Act.TANH - - raise ValueError(f"Unknown activation {nonlin}") - - class ConvNet(nn.Module): """ Wrapper for convolutional nets for classification and regression. 
@@ -437,7 +420,7 @@ def suggest_image_generator_discriminator_arch( strides=[2, 2, 2, 1], kernel_size=3, dropout=generator_dropout, - act=map_nonlin(generator_nonlin), + act=get_nonlin(generator_nonlin), num_res_units=generator_n_residual_units, ), nn.Tanh(), @@ -449,7 +432,7 @@ def suggest_image_generator_discriminator_arch( kernel_size=3, last_act=None, dropout=discriminator_dropout, - act=map_nonlin(generator_nonlin), + act=get_nonlin(generator_nonlin), num_res_units=discriminator_n_residual_units, ).to(device) @@ -559,8 +542,8 @@ def suggest_image_classifier_arch( classes=classes, channels=[16, 32, 64, 1], strides=[start_stride, 2, 2, 2], - act=map_nonlin(nonlin), - last_act=map_nonlin(last_nonlin), + act=get_nonlin(nonlin), + last_act=get_nonlin(last_nonlin), dropout=dropout, num_res_units=n_residual_units, ).to(device) diff --git a/src/synthcity/plugins/core/models/goggle.py b/src/synthcity/plugins/core/models/goggle.py index bd498507..8d35a720 100644 --- a/src/synthcity/plugins/core/models/goggle.py +++ b/src/synthcity/plugins/core/models/goggle.py @@ -19,11 +19,14 @@ # synthcity absolute import synthcity.logger as log from synthcity.plugins.core.dataloader import DataLoader -from synthcity.plugins.core.models.mlp import MultiActivationHead, get_nonlin -from synthcity.plugins.core.models.RGCNConv import RGCNConv from synthcity.utils.constants import DEVICE from synthcity.utils.reproducibility import clear_cache, enable_reproducible_results +# synthcity relative +from .factory import get_nonlin +from .layers import MultiActivationHead +from .RGCNConv import RGCNConv + class Goggle(nn.Module): @validate_arguments(config=dict(arbitrary_types_allowed=True)) diff --git a/tests/plugins/core/models/test_convnet.py b/tests/plugins/core/models/test_convnet.py index 71399626..a6659dec 100644 --- a/tests/plugins/core/models/test_convnet.py +++ b/tests/plugins/core/models/test_convnet.py @@ -7,7 +7,7 @@ # synthcity absolute from synthcity.plugins.core.models.convnet 
import ( - map_nonlin, + get_nonlin, suggest_image_classifier_arch, suggest_image_generator_discriminator_arch, ) @@ -16,7 +16,7 @@ @pytest.mark.parametrize("nonlin", ["relu", "elu", "prelu", "leaky_relu"]) def test_get_nonlin(nonlin: str) -> None: - assert map_nonlin(nonlin) is not None + assert get_nonlin(nonlin) is not None @pytest.mark.parametrize("n_channels", [1, 3]) From 8051caa9f080570efec9bbd9c1c53c08afdba64c Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sat, 8 Apr 2023 23:33:53 +0200 Subject: [PATCH 63/95] add more activations --- src/synthcity/plugins/core/models/convnet.py | 2 -- src/synthcity/plugins/core/models/factory.py | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/synthcity/plugins/core/models/convnet.py b/src/synthcity/plugins/core/models/convnet.py index a80bbb12..67065705 100644 --- a/src/synthcity/plugins/core/models/convnet.py +++ b/src/synthcity/plugins/core/models/convnet.py @@ -4,8 +4,6 @@ # third party import numpy as np import torch - -# from monai.networks.layers.factories import Act from monai.networks.nets import Classifier, Discriminator, Generator from pydantic import validate_arguments from torch import nn diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index a23ffc06..4de2980b 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -66,7 +66,10 @@ relu6=nn.ReLU6, celu=nn.CELU, glu=nn.GLU, + prelu=nn.PReLU, + relu6=nn.ReLU6, logsigmoid=nn.LogSigmoid, + logsoftmax=nn.LogSoftmax, softplus=nn.Softplus, ) From 3cd9917409433401bb0ca1ba69d150a72c846c0b Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 9 Apr 2023 00:28:25 +0200 Subject: [PATCH 64/95] minor fix --- src/synthcity/plugins/core/models/factory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/synthcity/plugins/core/models/factory.py 
b/src/synthcity/plugins/core/models/factory.py index 4de2980b..a34897c3 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -28,7 +28,6 @@ ) from .layers import GumbelSoftmax -# should only contain nn modules that can be used as building blocks in larger models MODELS = dict( mlp=".mlp.MLP", # attention models @@ -63,7 +62,6 @@ silu=nn.SiLU, swish=nn.SiLU, hardtanh=nn.Hardtanh, - relu6=nn.ReLU6, celu=nn.CELU, glu=nn.GLU, prelu=nn.PReLU, @@ -117,6 +115,7 @@ def get_model(block: Union[str, type], params: dict) -> Any: Named models: - mlp - rnn + - gru - lstm - transformer - tabnet From 42cbe8c34070bb4f742ae61682febc4b977c9a62 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 9 Apr 2023 09:45:11 +0200 Subject: [PATCH 65/95] update --- src/synthcity/plugins/core/models/convnet.py | 29 ++++++++++++++++---- tests/plugins/core/models/test_convnet.py | 4 +-- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/synthcity/plugins/core/models/convnet.py b/src/synthcity/plugins/core/models/convnet.py index 67065705..ae4260e6 100644 --- a/src/synthcity/plugins/core/models/convnet.py +++ b/src/synthcity/plugins/core/models/convnet.py @@ -4,17 +4,36 @@ # third party import numpy as np import torch +from monai.networks.layers.factories import Act from monai.networks.nets import Classifier, Discriminator, Generator from pydantic import validate_arguments from torch import nn # synthcity absolute import synthcity.logger as log -from synthcity.plugins.core.models.factory import get_nonlin from synthcity.utils.constants import DEVICE from synthcity.utils.reproducibility import enable_reproducible_results +def map_nonlin(nonlin: str) -> Act: + if nonlin == "relu": + return Act.RELU + elif nonlin == "elu": + return Act.ELU + elif nonlin == "prelu": + return Act.PRELU + elif nonlin == "leaky_relu": + return Act.LEAKYRELU + elif nonlin == "sigmoid": + return Act.SIGMOID + elif nonlin == 
"softmax": + return Act.SOFTMAX + elif nonlin == "tanh": + return Act.TANH + + raise ValueError(f"Unknown activation {nonlin}") + + class ConvNet(nn.Module): """ Wrapper for convolutional nets for classification and regression. @@ -418,7 +437,7 @@ def suggest_image_generator_discriminator_arch( strides=[2, 2, 2, 1], kernel_size=3, dropout=generator_dropout, - act=get_nonlin(generator_nonlin), + act=map_nonlin(generator_nonlin), num_res_units=generator_n_residual_units, ), nn.Tanh(), @@ -430,7 +449,7 @@ def suggest_image_generator_discriminator_arch( kernel_size=3, last_act=None, dropout=discriminator_dropout, - act=get_nonlin(generator_nonlin), + act=map_nonlin(generator_nonlin), num_res_units=discriminator_n_residual_units, ).to(device) @@ -540,8 +559,8 @@ def suggest_image_classifier_arch( classes=classes, channels=[16, 32, 64, 1], strides=[start_stride, 2, 2, 2], - act=get_nonlin(nonlin), - last_act=get_nonlin(last_nonlin), + act=map_nonlin(nonlin), + last_act=map_nonlin(last_nonlin), dropout=dropout, num_res_units=n_residual_units, ).to(device) diff --git a/tests/plugins/core/models/test_convnet.py b/tests/plugins/core/models/test_convnet.py index a6659dec..71399626 100644 --- a/tests/plugins/core/models/test_convnet.py +++ b/tests/plugins/core/models/test_convnet.py @@ -7,7 +7,7 @@ # synthcity absolute from synthcity.plugins.core.models.convnet import ( - get_nonlin, + map_nonlin, suggest_image_classifier_arch, suggest_image_generator_discriminator_arch, ) @@ -16,7 +16,7 @@ @pytest.mark.parametrize("nonlin", ["relu", "elu", "prelu", "leaky_relu"]) def test_get_nonlin(nonlin: str) -> None: - assert get_nonlin(nonlin) is not None + assert map_nonlin(nonlin) is not None @pytest.mark.parametrize("n_channels", [1, 3]) From 7c58f2d3e2f90b0b79dcf3fb313056b8e2656eb4 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 9 Apr 2023 18:21:17 +0200 Subject: [PATCH 66/95] update --- tests/plugins/generic/test_goggle.py | 3 --- 1 file changed, 3 deletions(-) 
diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 9b58ac4e..2c9b5f4a 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -106,9 +106,6 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: assert (X_gen1.numpy() != X_gen3.numpy()).any() -is_missing_goggle_deps = True - - @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", From 104e3a39ef7d7d11a93d8c920781423eb6251088 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 9 Apr 2023 18:22:05 +0200 Subject: [PATCH 67/95] update --- tests/plugins/generic/test_goggle.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 9b58ac4e..2c9b5f4a 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -106,9 +106,6 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: assert (X_gen1.numpy() != X_gen3.numpy()).any() -is_missing_goggle_deps = True - - @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", From 7b4e04ad4552980482a74096127357851ef8ebbe Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Sun, 9 Apr 2023 21:03:02 +0200 Subject: [PATCH 68/95] update --- src/synthcity/plugins/core/constraints.py | 4 +++- tests/plugins/generic/test_goggle.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/synthcity/plugins/core/constraints.py b/src/synthcity/plugins/core/constraints.py index dc79e56b..be1b2cc8 100644 --- a/src/synthcity/plugins/core/constraints.py +++ b/src/synthcity/plugins/core/constraints.py @@ -164,9 +164,11 @@ def filter(self, X: pd.DataFrame) -> pd.DataFrame: thresh, ) if res.sum() < prev: - log.info( + log.critical( f"[{feature}] quality loss for constraints {op} 
= {thresh}. Remaining {res.sum()}. prev length {prev}. Original dtype {X[feature].dtype}.", ) + if res.sum() < 5: + log.critical(str(X[~res])) return res @validate_arguments(config=dict(arbitrary_types_allowed=True)) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 2c9b5f4a..9b194ae0 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -17,7 +17,7 @@ plugin_name = "goggle" plugin_args = { - "n_iter": 500, + "n_iter": 10, "device": "cpu", } From fede549c81db710193c5d1c65128003c8cdd0937 Mon Sep 17 00:00:00 2001 From: Tianzhang Cai <13818704679@163.com> Date: Mon, 10 Apr 2023 01:17:42 +0100 Subject: [PATCH 69/95] Update tabular_encoder.py --- src/synthcity/plugins/core/models/tabular_encoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 364a6b57..fe82aca4 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -286,6 +286,7 @@ def activation_layout( d = 0 d += 1 out.append((acts[ct], d)) + log.info(out) return out From 539effaac7853a78a6de3cf92e2a9a08d17fd3a6 Mon Sep 17 00:00:00 2001 From: Tianzhang Cai <13818704679@163.com> Date: Mon, 10 Apr 2023 02:37:11 +0100 Subject: [PATCH 70/95] Update test_goggle.py --- tests/plugins/generic/test_goggle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 9b194ae0..1188b20d 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -19,6 +19,7 @@ plugin_args = { "n_iter": 10, "device": "cpu", + "sampling_patience": 50 } if not is_missing_goggle_deps: From 0cb9f25fce1722839a48147d19d37164f9c0c117 Mon Sep 17 00:00:00 2001 From: Tianzhang Cai <13818704679@163.com> Date: Mon, 10 Apr 2023 02:39:29 +0100 Subject: [PATCH 71/95] Update 
tabular_encoder.py --- src/synthcity/plugins/core/models/tabular_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index fe82aca4..8c4d06ea 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -286,7 +286,7 @@ def activation_layout( d = 0 d += 1 out.append((acts[ct], d)) - log.info(out) + log.critical(out) return out From 42c69413db0be0c42e851edbe8a677efbe75ea1d Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 08:52:31 +0200 Subject: [PATCH 72/95] update --- src/synthcity/plugins/core/models/tabular_encoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 8c4d06ea..54a3d478 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -275,7 +275,6 @@ def activation_layout( """ out = [] acts = dict(discrete=discrete_activation, continuous=continuous_activation) - # NOTE: be careful with the dim of softmax! 
for column_transform_info in self._column_transform_info_list: ct = column_transform_info.trans_feature_types[0] d = 0 From d7d966d63ad35777dce5a798c22a48050b2592e0 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 09:11:21 +0200 Subject: [PATCH 73/95] update tutorial 8 --- ...utorial8_hyperparameter_optimization.ipynb | 10914 +--------------- 1 file changed, 34 insertions(+), 10880 deletions(-) diff --git a/tutorials/tutorial8_hyperparameter_optimization.ipynb b/tutorials/tutorial8_hyperparameter_optimization.ipynb index f95c1e0c..971dd38d 100644 --- a/tutorials/tutorial8_hyperparameter_optimization.ipynb +++ b/tutorials/tutorial8_hyperparameter_optimization.ipynb @@ -12,24 +12,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[KeOps] Warning : \n", - " The default C++ compiler could not be found on your system.\n", - " You need to either define the CXX environment variable or a symlink to the g++ command.\n", - " For example if g++-8 is the command you can do\n", - " import os\n", - " os.environ['CXX'] = 'g++-8'\n", - " \n", - "[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode\n" - ] - } - ], + "outputs": [], "source": [ "# stdlib\n", "import sys\n", @@ -58,238 +43,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
agesexbmibps1s2s3s4s5s6target
00.0380760.0506800.0616960.021872-0.044223-0.034821-0.043401-0.0025920.019907-0.017646151.0
1-0.001882-0.044642-0.051474-0.026328-0.008449-0.0191630.074412-0.039493-0.068332-0.09220475.0
20.0852990.0506800.044451-0.005670-0.045599-0.034194-0.032356-0.0025920.002861-0.025930141.0
3-0.089063-0.044642-0.011595-0.0366560.0121910.024991-0.0360380.0343090.022688-0.009362206.0
40.005383-0.044642-0.0363850.0218720.0039350.0155960.008142-0.002592-0.031988-0.046641135.0
....................................
4370.0417080.0506800.0196620.059744-0.005697-0.002566-0.028674-0.0025920.0311930.007207178.0
438-0.0055150.050680-0.015906-0.0676420.0493410.079165-0.0286740.034309-0.0181140.044485104.0
4390.0417080.050680-0.0159060.017293-0.037344-0.013840-0.024993-0.011080-0.0468830.015491132.0
440-0.045472-0.0446420.0390620.0012150.0163180.015283-0.0286740.0265600.044529-0.025930220.0
441-0.045472-0.044642-0.073030-0.0814130.0837400.0278090.173816-0.039493-0.0042220.00306457.0
\n", - "

442 rows × 11 columns

\n", - "
" - ], - "text/plain": [ - " age sex bmi bp s1 s2 s3 \\\n", - "0 0.038076 0.050680 0.061696 0.021872 -0.044223 -0.034821 -0.043401 \n", - "1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 -0.019163 0.074412 \n", - "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 -0.034194 -0.032356 \n", - "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 0.024991 -0.036038 \n", - "4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 \n", - ".. ... ... ... ... ... ... ... \n", - "437 0.041708 0.050680 0.019662 0.059744 -0.005697 -0.002566 -0.028674 \n", - "438 -0.005515 0.050680 -0.015906 -0.067642 0.049341 0.079165 -0.028674 \n", - "439 0.041708 0.050680 -0.015906 0.017293 -0.037344 -0.013840 -0.024993 \n", - "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 0.015283 -0.028674 \n", - "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 0.027809 0.173816 \n", - "\n", - " s4 s5 s6 target \n", - "0 -0.002592 0.019907 -0.017646 151.0 \n", - "1 -0.039493 -0.068332 -0.092204 75.0 \n", - "2 -0.002592 0.002861 -0.025930 141.0 \n", - "3 0.034309 0.022688 -0.009362 206.0 \n", - "4 -0.002592 -0.031988 -0.046641 135.0 \n", - ".. ... ... ... ... 
\n", - "437 -0.002592 0.031193 0.007207 178.0 \n", - "438 0.034309 -0.018114 0.044485 104.0 \n", - "439 -0.011080 -0.046883 0.015491 132.0 \n", - "440 0.026560 0.044529 -0.025930 220.0 \n", - "441 -0.039493 -0.004222 0.003064 57.0 \n", - "\n", - "[442 rows x 11 columns]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "X, y = load_diabetes(return_X_y=True, as_frame=True)\n", "X[\"target\"] = y\n", @@ -298,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -320,30 +76,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2023-04-08T20:56:15.722354+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-08T20:56:15.722354+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n" - ] - }, - { - "data": { - "text/plain": [ - "synthcity.plugins.generic.plugin_nflow.NormalizingFlowsPlugin" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "PLUGIN = \"nflow\"\n", + "PLUGIN = \"tvae\"\n", "plugin_cls = type(Plugins().get(PLUGIN))\n", "plugin_cls" ] @@ -358,28 +95,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[IntegerDistribution(name='n_iter', data=None, random_state=0, marginal_distribution=None, low=100, high=5000, step=100),\n", - " IntegerDistribution(name='n_layers_hidden', data=None, random_state=0, marginal_distribution=None, low=1, high=10, step=1),\n", - " IntegerDistribution(name='n_units_hidden', data=None, random_state=0, marginal_distribution=None, low=10, 
high=100, step=1),\n", - " CategoricalDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, choices=[32, 64, 128, 256, 512]),\n", - " FloatDistribution(name='dropout', data=None, random_state=0, marginal_distribution=None, low=0.0, high=0.2),\n", - " CategoricalDistribution(name='batch_norm', data=None, random_state=0, marginal_distribution=None, choices=[False, True]),\n", - " CategoricalDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, choices=[0.0001, 0.0002, 0.001]),\n", - " CategoricalDistribution(name='linear_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['lu', 'permutation', 'svd']),\n", - " CategoricalDistribution(name='base_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['affine-autoregressive', 'affine-coupling', 'quadratic-autoregressive', 'quadratic-coupling', 'rq-autoregressive', 'rq-coupling'])]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "plugin_cls.hyperparameter_space()" ] @@ -394,28 +112,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'n_iter': 100,\n", - " 'n_layers_hidden': 1,\n", - " 'n_units_hidden': 87,\n", - " 'batch_size': 256,\n", - " 'dropout': 0.15424246144819787,\n", - " 'batch_norm': False,\n", - " 'lr': 0.001,\n", - " 'linear_transform_type': 'svd',\n", - " 'base_transform_type': 'rq-autoregressive'}" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from synthcity.utils.optuna_sample import suggest_all\n", "\n", @@ -435,82 +134,9 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 100/100 [00:38<00:00, 2.56it/s]\n", 
- "[2023-04-08T20:57:54.561757+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:30<00:00, 3.24it/s]\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
minmaxmeanstddevmedianiqrroundserrorsdurationsdirection
detection.detection_mlp.mean0.50.50.50.00.50.0105.95minimize
\n", - "
" - ], - "text/plain": [ - " min max mean stddev median iqr rounds \\\n", - "detection.detection_mlp.mean 0.5 0.5 0.5 0.0 0.5 0.0 1 \n", - "\n", - " errors durations direction \n", - "detection.detection_mlp.mean 0 5.95 minimize " - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from synthcity.benchmark import Benchmarks\n", "\n", @@ -534,106 +160,9 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2023-04-08T21:26:16.278633+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-08T21:26:16.301778+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:39<00:00, 2.56it/s]\n", - "[2023-04-08T21:26:59.665597+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-08T21:26:59.684496+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:26<00:00, 3.74it/s]\n", - "[2023-04-08T21:27:30.475951+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-08T21:27:30.495645+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:04<00:00, 21.72it/s]\n", - "[2023-04-08T21:27:39.102805+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-08T21:27:39.117858+0200][28420][CRITICAL] module disabled: 
D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [03:07<00:00, 1.88s/it]\n", - "[2023-04-08T21:31:35.546758+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "[2023-04-08T21:31:35.638203+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - " 0%| | 0/100 [00:00", - "n_iter (IntDistribution): 0.0", - "n_layers_hidden (IntDistribution): 7.256702823093135e-31", - "batch_norm (CategoricalDistribution): 0.02499999999999996", - "n_units_hidden (IntDistribution): 0.025000000000000088", - "batch_size (CategoricalDistribution): 0.04999999999999992", - "linear_transform_type (CategoricalDistribution): 0.09999999999999984", - "dropout (FloatDistribution): 0.275000000000001", - "base_transform_type (CategoricalDistribution): 0.5249999999999992" - ], - "marker": { - "color": "rgb(66,146,198)" - }, - "orientation": "h", - "text": [ - "<0.01", - "<0.01", - "<0.01", - "0.02", - "0.03", - "0.05", - "0.10", - "0.28", - "0.52" - ], - "textposition": "outside", - "type": "bar", - "x": [ - 0, - 0, - 7.256702823093135e-31, - 0.02499999999999996, - 0.025000000000000088, - 0.04999999999999992, - 0.09999999999999984, - 0.275000000000001, - 0.5249999999999992 - ], - "y": [ - "lr", - "n_iter", - "n_layers_hidden", - "batch_norm", - "n_units_hidden", - "batch_size", - "linear_transform_type", - "dropout", - "base_transform_type" - ] - } - ], - "layout": { - "showlegend": false, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - 
"fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - 
"#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - 
"colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": 
{ - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - 
"linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Hyperparameter Importances" - }, - "xaxis": { - "title": { - "text": "Importance for Objective Value" - } - }, - "yaxis": { - "title": { - "text": "Hyperparameter" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# Visualize parameter importances.\n", "plot_param_importances(study)" @@ -9047,908 +256,9 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "cliponaxis": false, - "hovertemplate": [ - "n_iter (IntDistribution): 0.0", - "batch_norm (CategoricalDistribution): 0.0008792721126079252", - 
"n_units_hidden (IntDistribution): 0.050652923907431195", - "lr (CategoricalDistribution): 0.07515326418808736", - "linear_transform_type (CategoricalDistribution): 0.08234393383908772", - "batch_size (CategoricalDistribution): 0.1506171783782107", - "dropout (FloatDistribution): 0.1928283779305551", - "n_layers_hidden (IntDistribution): 0.20147707299584372", - "base_transform_type (CategoricalDistribution): 0.2460479766481761" - ], - "marker": { - "color": "rgb(66,146,198)" - }, - "orientation": "h", - "text": [ - "<0.01", - "<0.01", - "0.05", - "0.08", - "0.08", - "0.15", - "0.19", - "0.20", - "0.25" - ], - "textposition": "outside", - "type": "bar", - "x": [ - 0, - 0.0008792721126079252, - 0.050652923907431195, - 0.07515326418808736, - 0.08234393383908772, - 0.1506171783782107, - 0.1928283779305551, - 0.20147707299584372, - 0.2460479766481761 - ], - "y": [ - "n_iter", - "batch_norm", - "n_units_hidden", - "lr", - "linear_transform_type", - "batch_size", - "dropout", - "n_layers_hidden", - "base_transform_type" - ] - } - ], - "layout": { - "showlegend": false, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - 
"ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" 
- ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - 
"colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 
0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": 
"#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Hyperparameter Importances" - }, - "xaxis": { - "title": { - "text": "Importance for duration" - } - }, - "yaxis": { - "title": { - "text": "Hyperparameter" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# Learn which hyperparameters are affecting the trial duration with hyperparameter importance.\n", "optuna.visualization.plot_param_importances(\n", @@ -9958,1067 +268,9 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "mode": "lines", - "name": "no-name-13d7c589-c089-4e41-befb-9c6abf683ae9", - "type": "scatter", - "x": [ - 0.4788647342995169, - 0.47907822183184506, - 0.47929170936417315, - 0.4795051968965013, - 0.47971868442882937, - 0.4799321719611575, - 0.4801456594934856, - 0.4803591470258137, - 0.4805726345581418, - 0.48078612209046995, - 0.48099960962279803, - 0.48121309715512617, - 0.48142658468745425, - 0.4816400722197824, - 0.48185355975211047, - 0.4820670472844386, - 0.4822805348167667, - 0.48249402234909483, - 0.4827075098814229, - 0.48292099741375105, - 0.48313448494607913, - 
0.4833479724784073, - 0.4835614600107354, - 0.4837749475430635, - 0.48398843507539163, - 0.4842019226077197, - 0.48441541014004785, - 0.48462889767237594, - 0.4848423852047041, - 0.48505587273703216, - 0.4852693602693603, - 0.4854828478016884, - 0.4856963353340165, - 0.4859098228663446, - 0.48612331039867274, - 0.4863367979310008, - 0.48655028546332896, - 0.48676377299565704, - 0.4869772605279852, - 0.48719074806031326, - 0.4874042355926414, - 0.48761772312496954, - 0.4878312106572976, - 0.48804469818962576, - 0.48825818572195384, - 0.488471673254282, - 0.48868516078661006, - 0.4888986483189382, - 0.4891121358512663, - 0.4893256233835944, - 0.4895391109159225, - 0.48975259844825064, - 0.4899660859805787, - 0.49017957351290686, - 0.49039306104523495, - 0.4906065485775631, - 0.49082003610989117, - 0.4910335236422193, - 0.4912470111745474, - 0.4914604987068755, - 0.49167398623920366, - 0.49188747377153175, - 0.4921009613038599, - 0.49231444883618797, - 0.4925279363685161, - 0.4927414239008442, - 0.4929549114331723, - 0.4931683989655004, - 0.49338188649782855, - 0.49359537403015663, - 0.49380886156248477, - 0.49402234909481285, - 0.494235836627141, - 0.4944493241594691, - 0.4946628116917972, - 0.4948762992241253, - 0.49508978675645343, - 0.4953032742887815, - 0.49551676182110965, - 0.4957302493534378, - 0.4959437368857659, - 0.496157224418094, - 0.4963707119504221, - 0.49658419948275023, - 0.4967976870150783, - 0.49701117454740645, - 0.49722466207973454, - 0.4974381496120627, - 0.49765163714439076, - 0.4978651246767189, - 0.498078612209047, - 0.4982920997413751, - 0.4985055872737032, - 0.49871907480603134, - 0.4989325623383594, - 0.49914604987068756, - 0.49935953740301564, - 0.4995730249353438, - 0.4997865124676719, - 0.5 - ], - "y": [ - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 
0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, - 0.14285714285714285, 
- 0.14285714285714285, - 0.14285714285714285, - 1 - ] - } - ], - "layout": { - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - 
"#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - 
"outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - 
"type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" 
- }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Empirical Distribution Function Plot" - }, - "xaxis": { - "title": { - "text": "Objective Value" - } - }, - "yaxis": { - "range": [ - 0, - 1 - ], - "title": { - "text": "Cumulative Probability" - } - } - } - } - }, - "metadata": {}, - "output_type": "display_data" - } 
- ], + "outputs": [], "source": [ "# Visualize empirical distribution function of the objective.\n", "plot_edf(study)" @@ -11034,107 +286,9 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2023-04-08T21:36:39.947037+0200][28420][CRITICAL] module disabled: D:\\Personal\\Work\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n", - "100%|██████████| 100/100 [00:20<00:00, 4.87it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[4m\u001b[1mPlugin : test\u001b[0m\u001b[0m\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
minmaxmeanstddevmedianiqrroundserrorsdurations
detection.detection_xgb.mean0.9885060.9885060.9885060.00.9885060.0100.18
detection.detection_mlp.mean0.7036400.7036400.7036400.00.7036400.0103.22
\n", - "
" - ], - "text/plain": [ - " min max mean stddev median \\\n", - "detection.detection_xgb.mean 0.988506 0.988506 0.988506 0.0 0.988506 \n", - "detection.detection_mlp.mean 0.703640 0.703640 0.703640 0.0 0.703640 \n", - "\n", - " iqr rounds errors durations \n", - "detection.detection_xgb.mean 0.0 1 0 0.18 \n", - "detection.detection_mlp.mean 0.0 1 0 3.22 " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ "best_params = study.best_params\n", "report = Benchmarks.evaluate(\n", From e20e5813559d57e0be462a3ff5da56264abc77b7 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 09:15:59 +0200 Subject: [PATCH 74/95] update --- tests/plugins/generic/test_goggle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 1188b20d..840f7c9c 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -19,7 +19,7 @@ plugin_args = { "n_iter": 10, "device": "cpu", - "sampling_patience": 50 + "sampling_patience": 50, } if not is_missing_goggle_deps: From 472ad523a7d26e8c742797a533d54535cc82961f Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 11:08:32 +0200 Subject: [PATCH 75/95] default cat nonlin of goggle <- gumbel_softmax --- src/synthcity/plugins/core/models/factory.py | 14 +++++++------- .../plugins/core/models/feature_encoder.py | 6 +++++- .../plugins/core/models/tabular_goggle.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index a34897c3..6f8f5e0c 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -30,13 +30,13 @@ MODELS = dict( mlp=".mlp.MLP", - # attention models - 
transformer=".transformer.TransformerModel", - tabnet=".tabnet.TabNet", # rnn models rnn=nn.RNN, gru=nn.GRU, lstm=nn.LSTM, + # attention models + transformer=".transformer.TransformerModel", + tabnet=".tabnet.TabNet", # time series models inceptiontime=InceptionTime, omniscalecnn=OmniScaleCNN, @@ -85,7 +85,7 @@ ) -def _factory(type_: Union[str, type], params: dict, registry: dict) -> Any: +def _get(type_: Union[str, type], params: dict, registry: dict) -> Any: if isinstance(type_, type): return type_(**params) type_ = type_.lower().replace("_", "") @@ -120,13 +120,13 @@ def get_model(block: Union[str, type], params: dict) -> Any: - transformer - tabnet """ - return _factory(block, params, MODELS) + return _get(block, params, MODELS) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def get_nonlin(nonlin: Union[str, nn.Module], params: dict = {}) -> Any: """Get a nonlinearity layer from a name or a class.""" - return _factory(nonlin, params, ACTIVATIONS) + return _get(nonlin, params, ACTIVATIONS) @validate_arguments(config=dict(arbitrary_types_allowed=True)) @@ -146,4 +146,4 @@ def get_feature_encoder(encoder: Union[str, type], params: dict = {}) -> Any: """ if isinstance(encoder, type): # custom encoder encoder = FeatureEncoder.wraps(encoder) - return _factory(encoder, params, FEATURE_ENCODERS) + return _get(encoder, params, FEATURE_ENCODERS) diff --git a/src/synthcity/plugins/core/models/feature_encoder.py b/src/synthcity/plugins/core/models/feature_encoder.py index 70807e31..25c98c01 100644 --- a/src/synthcity/plugins/core/models/feature_encoder.py +++ b/src/synthcity/plugins/core/models/feature_encoder.py @@ -11,6 +11,7 @@ LabelEncoder, MinMaxScaler, OneHotEncoder, + OrdinalEncoder, QuantileTransformer, RobustScaler, StandardScaler, @@ -150,7 +151,10 @@ def get_feature_names_out(self) -> List[str]: return WrappedEncoder -OneHotEncoder = FeatureEncoder.wraps(OneHotEncoder, categorical=True) +OneHotEncoder = FeatureEncoder.wraps( + OneHotEncoder, 
categorical=True, handle_unknown="ignore" +) +OrdinalEncoder = FeatureEncoder.wraps(OrdinalEncoder, categorical=True) LabelEncoder = FeatureEncoder.wraps(LabelEncoder, n_dim_out=1, categorical=True) StandardScaler = FeatureEncoder.wraps(StandardScaler) MinMaxScaler = FeatureEncoder.wraps(MinMaxScaler) diff --git a/src/synthcity/plugins/core/models/tabular_goggle.py b/src/synthcity/plugins/core/models/tabular_goggle.py index 84a1051b..478109b7 100644 --- a/src/synthcity/plugins/core/models/tabular_goggle.py +++ b/src/synthcity/plugins/core/models/tabular_goggle.py @@ -43,7 +43,7 @@ def __init__( decoder_nonlin: str = "relu", encoder_max_clusters: int = 20, encoder_whitelist: list = [], - decoder_nonlin_out_discrete: str = "softmax", + decoder_nonlin_out_discrete: str = "gumbel_softmax", decoder_nonlin_out_continuous: str = "tanh", random_state: int = 0, ): @@ -107,7 +107,7 @@ def __init__( The max number of clusters to create for continuous columns when encoding with TabularEncoder. Defaults to 20. encoder_whitelist: list = [] Ignore columns from encoding with TabularEncoder. Defaults to []. - decoder_nonlin_out_discrete: str = "softmax" + decoder_nonlin_out_discrete: str = "gumbel_softmax" Activation function for discrete columns. Useful with the TabularEncoder. Defaults to "softmax". decoder_nonlin_out_continuous: str = "tanh Activation function for continuous columns. Useful with the TabularEncoder.. Defaults to "tanh". 
From 5dbe66685fbb6d89cb7c653dce5f595c69d2f603 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 11:33:27 +0200 Subject: [PATCH 76/95] get_nonlin('softmax') <- GumbelSoftmax() --- src/synthcity/plugins/core/models/factory.py | 4 ++-- src/synthcity/plugins/core/models/tabular_goggle.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index 6f8f5e0c..c4fb7841 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -56,8 +56,8 @@ selu=nn.SELU, tanh=nn.Tanh, sigmoid=nn.Sigmoid, - softmax=nn.Softmax, - gumbelsoftmax=GumbelSoftmax, + softmax=GumbelSoftmax, + vanilla_softmax=nn.Softmax, gelu=nn.GELU, silu=nn.SiLU, swish=nn.SiLU, diff --git a/src/synthcity/plugins/core/models/tabular_goggle.py b/src/synthcity/plugins/core/models/tabular_goggle.py index 478109b7..84a1051b 100644 --- a/src/synthcity/plugins/core/models/tabular_goggle.py +++ b/src/synthcity/plugins/core/models/tabular_goggle.py @@ -43,7 +43,7 @@ def __init__( decoder_nonlin: str = "relu", encoder_max_clusters: int = 20, encoder_whitelist: list = [], - decoder_nonlin_out_discrete: str = "gumbel_softmax", + decoder_nonlin_out_discrete: str = "softmax", decoder_nonlin_out_continuous: str = "tanh", random_state: int = 0, ): @@ -107,7 +107,7 @@ def __init__( The max number of clusters to create for continuous columns when encoding with TabularEncoder. Defaults to 20. encoder_whitelist: list = [] Ignore columns from encoding with TabularEncoder. Defaults to []. - decoder_nonlin_out_discrete: str = "gumbel_softmax" + decoder_nonlin_out_discrete: str = "softmax" Activation function for discrete columns. Useful with the TabularEncoder. Defaults to "softmax". decoder_nonlin_out_continuous: str = "tanh Activation function for continuous columns. Useful with the TabularEncoder.. Defaults to "tanh". 
From 74e897ba8b3b3c9aa5fe0a0e68a1b255274ec51e Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 13:14:01 +0200 Subject: [PATCH 77/95] remove debug logging --- src/synthcity/plugins/core/constraints.py | 2 -- src/synthcity/plugins/core/models/tabular_encoder.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/synthcity/plugins/core/constraints.py b/src/synthcity/plugins/core/constraints.py index be1b2cc8..7660473f 100644 --- a/src/synthcity/plugins/core/constraints.py +++ b/src/synthcity/plugins/core/constraints.py @@ -167,8 +167,6 @@ def filter(self, X: pd.DataFrame) -> pd.DataFrame: log.critical( f"[{feature}] quality loss for constraints {op} = {thresh}. Remaining {res.sum()}. prev length {prev}. Original dtype {X[feature].dtype}.", ) - if res.sum() < 5: - log.critical(str(X[~res])) return res @validate_arguments(config=dict(arbitrary_types_allowed=True)) diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 54a3d478..45b13f50 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -285,7 +285,6 @@ def activation_layout( d = 0 d += 1 out.append((acts[ct], d)) - log.critical(out) return out From 27553e958c5d2e4a73f042628cc50047a374ea3e Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 13:15:34 +0200 Subject: [PATCH 78/95] update --- src/synthcity/plugins/core/constraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/synthcity/plugins/core/constraints.py b/src/synthcity/plugins/core/constraints.py index 7660473f..dc79e56b 100644 --- a/src/synthcity/plugins/core/constraints.py +++ b/src/synthcity/plugins/core/constraints.py @@ -164,7 +164,7 @@ def filter(self, X: pd.DataFrame) -> pd.DataFrame: thresh, ) if res.sum() < prev: - log.critical( + log.info( f"[{feature}] quality loss for constraints {op} = {thresh}. 
Remaining {res.sum()}. prev length {prev}. Original dtype {X[feature].dtype}.", ) return res From 7fc5ce4bb147281e79f8eaad150ffd620dfc4216 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 10 Apr 2023 13:19:05 +0200 Subject: [PATCH 79/95] update --- tests/plugins/generic/test_goggle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 840f7c9c..9b194ae0 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -19,7 +19,6 @@ plugin_args = { "n_iter": 10, "device": "cpu", - "sampling_patience": 50, } if not is_missing_goggle_deps: From b8c952253c8c9d5a5835ae6def0f8fecb10a4fcc Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 18 Apr 2023 12:01:37 +0200 Subject: [PATCH 80/95] fix merge --- src/synthcity/plugins/core/distribution.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index b8df0355..f53ee563 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -379,7 +379,6 @@ class DatetimeDistribution(Distribution): :parts: 1 """ - offset: int = 120 low: datetime = datetime.utcfromtimestamp(0) high: datetime = datetime.now() step: timedelta = timedelta(microseconds=1) From ecc9d086232188d1ded0aa9f7168e6cbacb6efdb Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Tue, 18 Apr 2023 12:05:03 +0200 Subject: [PATCH 81/95] fix merge --- src/synthcity/plugins/core/models/tabnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/synthcity/plugins/core/models/tabnet.py b/src/synthcity/plugins/core/models/tabnet.py index 0a57c3da..a9a04fe3 100644 --- a/src/synthcity/plugins/core/models/tabnet.py +++ b/src/synthcity/plugins/core/models/tabnet.py @@ -1,7 +1,7 @@ -TabNet: Attentive Interpretable Tabular Learning -Reference: -- 
https://arxiv.org/pdf/1908.07442.pdf -- https://github.com/dreamquark-ai/tabnet +# TabNet: Attentive Interpretable Tabular Learning +# Reference: +# - https://arxiv.org/pdf/1908.07442.pdf +# - https://github.com/dreamquark-ai/tabnet # stdlib from typing import List, Optional, Tuple From c2775bac31f9fe904b2f90f00a4576b5f8e42ae6 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 00:29:39 +0200 Subject: [PATCH 82/95] update pip upgrade commands in workflows --- .github/workflows/test_full.yml | 2 +- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index e3265692..1b497579 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,8 +27,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | + python -m pip install --upgrade pip pip install -r prereq.txt - pip install --upgrade pip - name: Test Core run: | pip install .[testing] diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 2c907433..4babfca1 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,8 +54,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | + python -m pip install --upgrade pip pip install -r prereq.txt - pip install --upgrade pip - name: Test Core run: | pip install .[testing] diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 212e79e2..bcbb0e1f 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,8 +32,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | + python -m pip install --upgrade pip pip install -r prereq.txt - pip install --upgrade pip pip install .[all] From 1d9c7a4fb06abd39674c06e8e8ca10f09bd72a0f Mon Sep 17 00:00:00 2001 
From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 00:33:31 +0200 Subject: [PATCH 83/95] update pip upgrade commands in workflows --- .github/workflows/test_full.yml | 2 +- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index e3265692..1b497579 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,8 +27,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | + python -m pip install --upgrade pip pip install -r prereq.txt - pip install --upgrade pip - name: Test Core run: | pip install .[testing] diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 2c907433..4babfca1 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,8 +54,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | + python -m pip install --upgrade pip pip install -r prereq.txt - pip install --upgrade pip - name: Test Core run: | pip install .[testing] diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 212e79e2..bcbb0e1f 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,8 +32,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | + python -m pip install --upgrade pip pip install -r prereq.txt - pip install --upgrade pip pip install .[all] From 385d2edaed7a3c50ca385e7ef0b96053a2bf2b81 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 00:52:17 +0200 Subject: [PATCH 84/95] keep pip version to 23.0.1 in workflows --- .github/workflows/test_full.yml | 2 +- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_full.yml 
b/.github/workflows/test_full.yml index 1b497579..e258db95 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,7 +27,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install --upgrade pip + pip install pip==23.0.1 pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 4babfca1..fec1126d 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install --upgrade pip + pip install pip==23.0.1 pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index bcbb0e1f..69195a09 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,7 +32,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install --upgrade pip + pip install pip==23.0.1 pip install -r prereq.txt pip install .[all] From 81fb12b973c410a3c99e21a5023c13e61dea17b6 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 00:57:58 +0200 Subject: [PATCH 85/95] keep pip version to 23.0.1 in workflows --- .github/workflows/test_full.yml | 2 +- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index 1b497579..e258db95 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,7 +27,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install --upgrade pip + pip install pip==23.0.1 pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_pr.yml 
b/.github/workflows/test_pr.yml index 4babfca1..fec1126d 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install --upgrade pip + pip install pip==23.0.1 pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index bcbb0e1f..69195a09 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,7 +32,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install --upgrade pip + pip install pip==23.0.1 pip install -r prereq.txt pip install .[all] From 3884fc439a7e5f276aa90f69b3b3e50712125e46 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 11:58:16 +0200 Subject: [PATCH 86/95] update --- .github/workflows/test_full.yml | 2 +- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index e258db95..b34df21b 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,7 +27,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - pip install pip==23.0.1 + python -m pip install pip==23.0.1 pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index fec1126d..527e4c36 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - pip install pip==23.0.1 + python -m pip install pip==23.0.1 pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 
69195a09..8a29b3c4 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,7 +32,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - pip install pip==23.0.1 + python -m pip install pip==23.0.1 pip install -r prereq.txt pip install .[all] From 7640f355b5a28611c5ff7c062a64451b917ade81 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 12:18:51 +0200 Subject: [PATCH 87/95] update --- .github/workflows/test_full.yml | 2 +- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index b34df21b..0f977ba1 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,7 +27,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install pip==23.0.1 + python -m pip install -U pip pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 527e4c36..bac8905e 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install pip==23.0.1 + python -m pip install -U pip pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 8a29b3c4..c93a0c35 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,7 +32,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install pip==23.0.1 + python -m pip install -U pip pip install -r prereq.txt pip install .[all] From c91246b8807bb4c51103f292bf255fe1dfa3f416 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 
20 Apr 2023 12:27:02 +0200 Subject: [PATCH 88/95] update --- prereq.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prereq.txt b/prereq.txt index 0d7eb1f0..ec51b0a3 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,3 @@ numpy -torch<2.0 +torch>=1.10.0,<2.0 tsai From 38fc7966ad5381b8ad6136ae10fb90b2dcf7e906 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 12:32:26 +0200 Subject: [PATCH 89/95] update --- .github/workflows/test_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index bac8905e..66bd6cc3 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - python -m pip install -U pip + # python -m pip install -U pip pip install -r prereq.txt - name: Test Core run: | From 899a9d83605ce8fb677e5ee3d82f7ee0163077d3 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 12:44:00 +0200 Subject: [PATCH 90/95] update --- .github/workflows/test_pr.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 66bd6cc3..e98fa172 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,8 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - # python -m pip install -U pip + python -m pip install -U pip + pip install --upgrade setuptools, wheel pip install -r prereq.txt - name: Test Core run: | From 60fa08da1b5e5e6ec74093a2bb6adcc33cab0534 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 12:46:53 +0200 Subject: [PATCH 91/95] update --- .github/workflows/test_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index e98fa172..40cc83df 
100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -55,7 +55,7 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip - pip install --upgrade setuptools, wheel + pip install --upgrade setuptools wheel pip install -r prereq.txt - name: Test Core run: | From 50a77c52eae6d6dc5f65b40b27b5393fa91388f4 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 13:00:42 +0200 Subject: [PATCH 92/95] fix distribution --- src/synthcity/plugins/core/distribution.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index f53ee563..788db4e0 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -384,12 +384,6 @@ class DatetimeDistribution(Distribution): step: timedelta = timedelta(microseconds=1) offset: timedelta = timedelta(seconds=120) - @validator("offset", always=True) - def _validate_offset(cls: Any, v: int) -> int: - if v < 0: - raise ValueError("offset must be greater than 0") - return v - @validator("low", always=True) def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" From 727662f1f5508aca90f413dc019f9a041cfda0ed Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Thu, 20 Apr 2023 17:05:26 +0200 Subject: [PATCH 93/95] update --- src/synthcity/plugins/generic/plugin_ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index 5d9f0ac0..8789d7d6 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -183,8 +183,8 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: IntLogDistribution(name="batch_size", low=256, high=4096), IntegerDistribution(name="num_timesteps", low=10, high=1000), 
IntLogDistribution(name="n_iter", low=1000, high=10000), - IntegerDistribution(name="n_layers_hidden", low=2, high=8), - IntLogDistribution(name="dim_hidden", low=128, high=1024), + # IntegerDistribution(name="n_layers_hidden", low=2, high=8), + # IntLogDistribution(name="dim_hidden", low=128, high=1024), ] def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin": From 212d7cb3c9bcbd550eaf78739c0df98c6cb9925b Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 24 Apr 2023 09:19:24 +0200 Subject: [PATCH 94/95] move upgrading of wheel to prereq.txt --- .github/workflows/test_pr.yml | 1 - prereq.txt | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 40cc83df..bac8905e 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -55,7 +55,6 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip - pip install --upgrade setuptools wheel pip install -r prereq.txt - name: Test Core run: | diff --git a/prereq.txt b/prereq.txt index ec51b0a3..125d1b13 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,4 @@ +wheel>=0.40 numpy torch>=1.10.0,<2.0 -tsai +tsai \ No newline at end of file From d8e63c3537f2c8724d23de2615efab6db5d5d034 Mon Sep 17 00:00:00 2001 From: TZCai <13818704679@163.com> Date: Mon, 24 Apr 2023 14:13:02 +0200 Subject: [PATCH 95/95] update --- prereq.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prereq.txt b/prereq.txt index 125d1b13..6554078b 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,4 +1,4 @@ -wheel>=0.40 numpy torch>=1.10.0,<2.0 -tsai \ No newline at end of file +tsai +wheel>=0.40