Skip to content

Commit

Permalink
Fix bugs, including PBCDataloader incorrect preprocessing (#152)
Browse files Browse the repository at this point in the history
* Fix no convergence XGBTimeSeriesSurvival tests

* Fix length mismatch error when no static data

* Fix PBCDataloader preprocessing

* Update DECAF

---------

Co-authored-by: Bogdan Cebere <[email protected]>
  • Loading branch information
DrShushen and bcebere authored Mar 16, 2023
1 parent cf6ea56 commit 33a0556
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ install_requires =
numpy>=1.20
lifelines>=0.27
opacus>=1.3
decaf-synthetic-data>=0.1.5
decaf-synthetic-data>=0.1.6
optuna>=3.1
shap
tqdm
Expand Down
6 changes: 3 additions & 3 deletions src/synthcity/metrics/eval_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def evaluate(
"n_jobs": 2,
"verbosity": 0,
"depth": 3,
"strategy": "weibull", # "weibull", "debiased_bce"
"strategy": "debiased_bce", # "weibull", "debiased_bce"
"random_state": self._random_state,
},
X_gt,
Expand Down Expand Up @@ -705,7 +705,7 @@ def evaluate(
"n_jobs": 2,
"verbosity": 0,
"depth": 3,
"strategy": "weibull", # "weibull", "debiased_bce"
"strategy": "debiased_bce", # "weibull", "debiased_bce"
"random_state": self._random_state,
},
X_gt,
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def evaluate(
n_jobs=2,
verbosity=0,
depth=3,
strategy="weibull", # "weibull", "debiased_bce"
strategy="debiased_bce", # "weibull", "debiased_bce"
random_state=self._random_state,
)

Expand Down
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/models/ts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,7 @@ def __init__(

self.device = device
self.window_size = window_size
self.n_static_units_in = n_static_units_in
self.model = MLP(
task_type="regression",
n_units_in=n_static_units_in + n_temporal_units_in * window_size,
Expand All @@ -719,7 +720,7 @@ def __init__(
def forward(
self, static_data: torch.Tensor, temporal_data: torch.Tensor
) -> torch.Tensor:
if len(static_data) != len(temporal_data):
if self.n_static_units_in > 0 and len(static_data) != len(temporal_data):
raise ValueError("Length mismatch between static and temporal data")

batch_size, seq_len, n_feats = temporal_data.shape
Expand Down
10 changes: 9 additions & 1 deletion src/synthcity/utils/datasets/time_series/pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _load_pbc_dataset(self, sequential: bool = True) -> Tuple:
data = pd.read_csv(df_path)

data["time"] = data["years"] - data["year"]
data = data.sort_values(by="time")
data = data.sort_values(by=["id", "time"], ignore_index=True)
data["histologic"] = data["histologic"].astype(str)
dat_cat = data[
["drug", "sex", "ascites", "hepatomegaly", "spiders", "edema", "histologic"]
Expand Down Expand Up @@ -136,6 +136,14 @@ def _load_pbc_dataset(self, sequential: bool = True) -> Tuple:
patient = x_[data["id"] == id_]

patient_static = patient[static_cols]
if (
not (patient_static.iloc[0] == patient_static).all().all()
): # pragma: no cover
# This is a sanity check.
raise RuntimeError(
"Found patient with static data that are not actually fixed:\n"
f"id: {id_}\n{patient_static}"
)
x_static.append(patient_static.values[0].tolist())

patient_temporal = patient[temporal_cols]
Expand Down

0 comments on commit 33a0556

Please sign in to comment.