Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix full test errors #255

Merged
merged 13 commits into from
Feb 29, 2024
15 changes: 13 additions & 2 deletions .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ jobs:
run: |
python -m pip install -U pip
pip install -r prereq.txt
- name: Test Core
- name: Test Core - slow part one
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50
pytest -vvvs --durations=50 -m "slow_1"
- name: Test Core - slow part two
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "slow_2"
- name: Test Core - fast
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "not slow"
- name: Test GOGGLE
run: |
pip install .[testing,goggle]
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install_requires =
scikit-learn>=1.2
nflows>=0.14
numpy>=1.20, <1.24
lifelines>=0.27,!= 0.27.5
lifelines>=0.27,!= 0.27.5, <0.27.8
opacus>=1.3
decaf-synthetic-data>=0.1.6
optuna>=3.1
Expand Down Expand Up @@ -117,6 +117,8 @@ testpaths = tests
# Use pytest markers to select/deselect specific tests
markers =
slow: mark tests as slow (deselect with '-m "not slow"')
slow_1: mark tests as slow (deselect with '-m "not slow_1"')
slow_2: mark tests as slow (deselect with '-m "not slow_1"')

[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
Expand Down
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,12 +928,13 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
self.data["observation_times"],
self.data["outcome"],
)

if as_numpy:
longest_observation_seq = max([len(seq) for seq in temporal_data])
return (
np.asarray(static_data),
np.asarray(
pd.concat(temporal_data)
temporal_data
), # TODO: check this works with time series benchmarks
# masked array to handle variable length sequences
ma.vstack(
Expand Down
2 changes: 0 additions & 2 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ class PluginLoader:

@validate_arguments
def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
# self.reload()
global PLUGIN_CATEGORY_REGISTRY
PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}
self._refresh()
Expand Down Expand Up @@ -639,7 +638,6 @@ def list(self) -> List[str]:
for plugin in all_plugins:
if self.get_type(plugin).type() in self._categories:
plugins.append(plugin)

return list(set(plugins))

def types(self) -> List[Type]:
Expand Down
2 changes: 2 additions & 0 deletions src/synthcity/plugins/privacy/plugin_dpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class DPGANPlugin(Plugin):
>>>
>>> plugin.generate(50)

Note: There is a known issue with the training step for training GANs with conditionals with dp_enabled set to True, as is the case for DPGAN.

"""

@validate_arguments(config=dict(arbitrary_types_allowed=True))
Expand Down
2 changes: 1 addition & 1 deletion src/synthcity/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.9"
__version__ = "0.2.10"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
PATCH_VERSION = __version__.split(".")[-1]
1 change: 1 addition & 0 deletions tests/metrics/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_detect_synth_timeseries(test_plugin: Plugin, evaluator_t: Type) -> None
assert evaluator.direction() == "minimize"


@pytest.mark.slow_1
@pytest.mark.slow
def test_image_support_detection() -> None:
dataset = datasets.MNIST(".", download=True)
Expand Down
6 changes: 6 additions & 0 deletions tests/metrics/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_evaluate_performance_classifier(
@pytest.mark.xfail
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.slow_1
@pytest.mark.slow
def test_evaluate_feature_importance_rank_dist_clf(
distance: str, test_plugin: Plugin
Expand Down Expand Up @@ -183,6 +184,7 @@ def test_evaluate_performance_regression(
@pytest.mark.xfail
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.slow_1
@pytest.mark.slow
def test_evaluate_feature_importance_rank_dist_reg(
distance: str, test_plugin: Plugin
Expand Down Expand Up @@ -211,6 +213,7 @@ def test_evaluate_feature_importance_rank_dist_reg(
assert score["pvalue"] > 0


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("test_plugin", [Plugins().get("marginal_distributions")])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -296,6 +299,7 @@ def test_evaluate_performance_survival_analysis(
@pytest.mark.xfail
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.slow_1
@pytest.mark.slow
def test_evaluate_feature_importance_rank_dist_surv(
distance: str, test_plugin: Plugin
Expand Down Expand Up @@ -362,6 +366,7 @@ def test_evaluate_performance_custom_labels(
assert "syn_ood" in good_score


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("test_plugin", [Plugins().get("timegan")])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -472,6 +477,7 @@ def test_evaluate_performance_time_series_survival(
assert def_score == good_score["syn_id.c_index"] - good_score["syn_id.brier_score"]


@pytest.mark.slow_1
@pytest.mark.slow
def test_image_support_perf() -> None:
dataset = datasets.MNIST(".", download=True)
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/core/models/test_tabular_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_gan_generation_with_early_stopping(patience_metric: Tuple[str, str]) ->
assert generated.shape == (10, X.shape[1])


@pytest.mark.slow_1
@pytest.mark.slow
def test_gan_sampling_adjustment() -> None:
X = get_airfoil_dataset()
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/core/models/test_ts_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_ts_gan_generation(source: Any) -> None:
assert observation_times_gen.shape == (10, temporal.shape[1])


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
def test_ts_gan_generation_schema(source: Any) -> None:
Expand Down
3 changes: 3 additions & 0 deletions tests/plugins/core/models/test_ts_tabular_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_network_config() -> None:
assert net.model.embedding_penalty == 2


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [SineDataloader, GoogleStocksDataloader])
def test_ts_gan_generation(source: Any) -> None:
Expand All @@ -86,6 +87,7 @@ def test_ts_gan_generation(source: Any) -> None:
)


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
def test_ts_gan_generation_schema(source: Any) -> None:
Expand Down Expand Up @@ -118,6 +120,7 @@ def test_ts_gan_generation_schema(source: Any) -> None:
assert reference_schema.as_constraints().filter(seq_df).sum() > 0


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [SineDataloader, GoogleStocksDataloader])
def test_ts_tabular_gan_conditional(source: Any) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/core/models/test_ts_tabular_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_ts_vae_generation(source: Any) -> None:
)


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
def test_ts_vae_generation_schema(source: Any) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_train_prediction_coxph(rnn_type: str, output_type: str) -> None:
assert score["clf"]["c_index"][0] > 0.5


@pytest.mark.slow_1
@pytest.mark.slow
def test_hyperparam_search() -> None:
static, temporal, observation_times, outcome = PBCDataloader(as_numpy=True).load()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_train_prediction_dyn_deephit(rnn_type: str, output_type: str) -> None:
assert score["clf"]["c_index"][0] > 0.5


@pytest.mark.slow_1
@pytest.mark.slow
def test_hyperparam_search() -> None:
static, temporal, observation_times, outcome = PBCDataloader(as_numpy=True).load()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_train_prediction(emb_rnn_type: str) -> None:
assert score["clf"]["c_index"][0] > 0.5


@pytest.mark.slow_1
@pytest.mark.slow
def test_hyperparam_search() -> None:
static, temporal, observation_times, outcome = PBCDataloader(as_numpy=True).load()
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/domain_adaptation/test_radialgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
def test_eval_performance_radialgan() -> None:
results = []
Expand Down
2 changes: 2 additions & 0 deletions tests/plugins/generic/test_arf.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_arf(compress_dataset: bool) -> None:
Expand Down Expand Up @@ -151,6 +152,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d
return start + (end - start) * random.random()


@pytest.mark.slow_1
@pytest.mark.slow
def test_plugin_encoding() -> None:
assert plugin is not None
Expand Down
2 changes: 2 additions & 0 deletions tests/plugins/generic/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_ctgan(compress_dataset: bool) -> None:
Expand Down Expand Up @@ -169,6 +170,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d
return start + (end - start) * random.random()


@pytest.mark.slow_1
@pytest.mark.slow
def test_plugin_encoding() -> None:
data = [[gen_datetime(), i % 2 == 0, i] for i in range(1000)]
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_ddpm(compress_dataset: bool) -> None:
Expand Down
41 changes: 18 additions & 23 deletions tests/plugins/generic/test_goggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.datasets import load_diabetes, load_iris

# synthcity absolute
from synthcity.metrics.eval import PerformanceEvaluatorXGB
from synthcity.metrics.eval import AlphaPrecision
from synthcity.plugins import Plugin
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import GenericDataLoader
Expand Down Expand Up @@ -149,39 +149,34 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


# TODO: Known issue goggle seems to have a performance issue.
# Testing fidelity instead. Also need to test more architectures
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.parametrize(
"compress_dataset,decoder_arch",
[
(True, "het"),
(False, "het"),
(True, "gcn"),
(False, "gcn"),
(True, "sage"),
(False, "sage"),
],
)
def test_eval_performance_goggle(compress_dataset: bool, decoder_arch: str) -> None:
def test_eval_fidelity_goggle(compress_dataset: bool, decoder_arch: str) -> None:
results = []

Xraw, y = load_diabetes(return_X_y=True, as_frame=True)
Xraw, y = load_iris(return_X_y=True, as_frame=True)
Xraw["target"] = y
X = GenericDataLoader(Xraw)

assert plugin is not None
for retry in range(2):
for retry in range(3):
test_plugin = plugin(
n_iter=5000,
compress_dataset=compress_dataset,
decoder_arch=decoder_arch,
encoder_dim=32,
encoder_l=4,
decoder_dim=32,
decoder_l=4,
data_encoder_max_clusters=20,
compress_dataset=False,
decoder_arch="gcn",
random_state=retry,
)
evaluator = PerformanceEvaluatorXGB()
evaluator = AlphaPrecision()

test_plugin.fit(X)
X_syn = test_plugin.generate()

results.append(evaluator.evaluate(X, X_syn)["syn_id"])
X_syn = test_plugin.generate(count=len(X), random_state=retry)
eval_results = evaluator.evaluate(X, X_syn)
results.append(eval_results["authenticity_OC"])

assert np.mean(results) > 0.7
2 changes: 2 additions & 0 deletions tests/plugins/generic/test_great.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
assert (X_gen1.numpy() != X_gen3.numpy()).any()


@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
Expand Down Expand Up @@ -185,6 +186,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d
return start + (end - start) * random.random()


@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_nflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_nflow(compress_dataset: bool) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_rtvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_2
@pytest.mark.slow
def test_eval_performance_rtvae() -> None:
results = []
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_2
@pytest.mark.slow
def test_eval_performance_tvae() -> None:
results = []
Expand Down
2 changes: 2 additions & 0 deletions tests/plugins/images/test_image_adsgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_plugin_generate() -> None:
assert len(X_gen) == 50


@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_conditional() -> None:
test_plugin = plugin(n_iter=10, n_units_latent=13)
Expand All @@ -71,6 +72,7 @@ def test_plugin_generate_with_conditional() -> None:
assert len(X_gen) == 50


@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_stop_conditional() -> None:
test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2)
Expand Down
Loading
Loading