diff --git a/setup.cfg b/setup.cfg index 24c79d02..7be378de 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ install_requires = numpy>=1.20 lifelines>=0.27 opacus>=1.3 - decaf-synthetic-data>=0.1.5 + decaf-synthetic-data>=0.1.6 optuna>=3.1 shap tqdm diff --git a/src/synthcity/metrics/eval_performance.py b/src/synthcity/metrics/eval_performance.py index fc44cf8d..8a24345e 100644 --- a/src/synthcity/metrics/eval_performance.py +++ b/src/synthcity/metrics/eval_performance.py @@ -675,7 +675,7 @@ def evaluate( "n_jobs": 2, "verbosity": 0, "depth": 3, - "strategy": "weibull", # "weibull", "debiased_bce" + "strategy": "debiased_bce", # "weibull", "debiased_bce" "random_state": self._random_state, }, X_gt, @@ -705,7 +705,7 @@ def evaluate( "n_jobs": 2, "verbosity": 0, "depth": 3, - "strategy": "weibull", # "weibull", "debiased_bce" + "strategy": "debiased_bce", # "weibull", "debiased_bce" "random_state": self._random_state, }, X_gt, @@ -1033,7 +1033,7 @@ def evaluate( n_jobs=2, verbosity=0, depth=3, - strategy="weibull", # "weibull", "debiased_bce" + strategy="debiased_bce", # "weibull", "debiased_bce" random_state=self._random_state, ) diff --git a/src/synthcity/plugins/core/models/ts_model.py b/src/synthcity/plugins/core/models/ts_model.py index aa5ef333..a7026146 100644 --- a/src/synthcity/plugins/core/models/ts_model.py +++ b/src/synthcity/plugins/core/models/ts_model.py @@ -704,6 +704,7 @@ def __init__( self.device = device self.window_size = window_size + self.n_static_units_in = n_static_units_in self.model = MLP( task_type="regression", n_units_in=n_static_units_in + n_temporal_units_in * window_size, @@ -719,7 +720,7 @@ def __init__( def forward( self, static_data: torch.Tensor, temporal_data: torch.Tensor ) -> torch.Tensor: - if len(static_data) != len(temporal_data): + if self.n_static_units_in > 0 and len(static_data) != len(temporal_data): raise ValueError("Length mismatch between static and temporal data") batch_size, seq_len, n_feats = temporal_data.shape diff --git a/src/synthcity/utils/datasets/time_series/pbc.py b/src/synthcity/utils/datasets/time_series/pbc.py index 57c7026f..185b0477 100644 --- a/src/synthcity/utils/datasets/time_series/pbc.py +++ b/src/synthcity/utils/datasets/time_series/pbc.py @@ -62,7 +62,7 @@ def _load_pbc_dataset(self, sequential: bool = True) -> Tuple: data = pd.read_csv(df_path) data["time"] = data["years"] - data["year"] - data = data.sort_values(by="time") + data = data.sort_values(by=["id", "time"], ignore_index=True) data["histologic"] = data["histologic"].astype(str) dat_cat = data[ ["drug", "sex", "ascites", "hepatomegaly", "spiders", "edema", "histologic"] @@ -136,6 +136,14 @@ def _load_pbc_dataset(self, sequential: bool = True) -> Tuple: patient = x_[data["id"] == id_] patient_static = patient[static_cols] + if ( + not (patient_static.iloc[0] == patient_static).all().all() + ): # pragma: no cover + # This is a sanity check. + raise RuntimeError( + "Found patient with static data that are not actually fixed:\n" + f"id: {id_}\n{patient_static}" + ) x_static.append(patient_static.values[0].tolist()) patient_temporal = patient[temporal_cols]