Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skipif goggle tests #170

Merged
12 changes: 8 additions & 4 deletions .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ jobs:
- name: Install dependencies
run: |
pip install -r prereq.txt

pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: pytest -vvvs --durations=50
- name: Test Core
run: |
pip install .[testing]
pytest -vvvs --durations=50
- name: Test GOGGLE
run: |
pip install .[testing,goggle]
pytest -vvvs -k goggle --durations=50
14 changes: 9 additions & 5 deletions .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Tests Python
name: Tests Fast Python

on:
push:
Expand Down Expand Up @@ -55,8 +55,12 @@ jobs:
- name: Install dependencies
run: |
pip install -r prereq.txt

pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: pytest -vvvs -m "not slow" --durations=50
- name: Test Core
run: |
pip install .[testing]
pytest -vvvsx -m "not slow" --durations=50
- name: Test GOGGLE
run: |
pip install .[testing,goggle]
pytest -vvvsx -m "not slow" -k goggle
2 changes: 1 addition & 1 deletion prereq.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy
torch
torch<2.0
tsai
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ python_requires = >=3.7
install_requires =
scikit-learn>=1.0
nflows>=0.14
pandas>=1.3
torch>=1.10.0
pandas>=1.3,<2.0
torch>=1.10,<2.0
numpy>=1.20
lifelines>=0.27
opacus>=1.3
Expand Down
24 changes: 20 additions & 4 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,16 @@ def __init__(self, plugins: list, expected_type: Type, categories: list) -> None
self._available_plugins = {}
for plugin in plugins:
stem = Path(plugin).stem.split("plugin_")[-1]
cls = self._load_single_plugin_impl(plugin)
if cls is None:
continue
self._available_plugins[stem] = plugin
self._expected_type = expected_type
self._categories = categories

@validate_arguments
def _load_single_plugin(self, plugin_name: str) -> None:
"""Helper for loading a single plugin"""
def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]:
"""Helper for loading a single plugin implementation"""
plugin = Path(plugin_name)
name = plugin.stem
ptype = plugin.parent.name
Expand All @@ -579,6 +582,10 @@ def _load_single_plugin(self, plugin_name: str) -> None:

spec.loader.exec_module(mod)
cls = mod.plugin
if cls is None:
log.critical(f"module disabled: {plugin_name}")
return None

failed = False
break
except BaseException as e:
Expand All @@ -587,10 +594,19 @@ def _load_single_plugin(self, plugin_name: str) -> None:

if failed:
log.critical(f"module {name} load failed")
return
return None

return cls

@validate_arguments
def _load_single_plugin(self, plugin_name: str) -> bool:
"""Helper for loading a single plugin"""
cls = self._load_single_plugin_impl(plugin_name)
if cls is None:
return False

log.debug(f"Loaded plugin {cls.type()} - {cls.name()}")
self.add(cls.name(), cls)
return True

def list(self) -> List[str]:
"""Get all the available plugins."""
Expand Down
14 changes: 12 additions & 2 deletions src/synthcity/plugins/generic/plugin_goggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,18 @@
FloatDistribution,
IntegerDistribution,
)
from synthcity.plugins.core.models.tabular_goggle import TabularGoggle
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.utils.constants import DEVICE

try:
# synthcity absolute
from synthcity.plugins.core.models.tabular_goggle import TabularGoggle

module_disabled = False
except ImportError:
module_disabled = True


class GOGGLEPlugin(Plugin):
"""
Expand Down Expand Up @@ -248,4 +255,7 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFra
return self._safe_generate(self.model.generate, count, syn_schema)


plugin = GOGGLEPlugin
if module_disabled:
plugin = None
else:
plugin = GOGGLEPlugin
13 changes: 9 additions & 4 deletions tests/plugins/generic/generic_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# stdlib
from typing import Dict, List, Type
from typing import Dict, List, Optional, Type

# third party
import pandas as pd
Expand All @@ -10,15 +10,20 @@
from synthcity.utils.serialization import load, save


def generate_fixtures(name: str, plugin: Type, plugin_args: Dict = {}) -> List:
def generate_fixtures(
name: str, plugin: Optional[Type], plugin_args: Dict = {}
) -> List:
if plugin is None:
return []

def from_api() -> Plugin:
return Plugins().get(name, **plugin_args)

def from_module() -> Plugin:
return plugin(**plugin_args)
return plugin(**plugin_args) # type: ignore

def from_serde() -> Plugin:
buff = save(plugin(**plugin_args))
buff = save(plugin(**plugin_args)) # type: ignore
return load(buff)

return [from_api(), from_module(), from_serde()]
Expand Down
48 changes: 41 additions & 7 deletions tests/plugins/generic/test_goggle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# third party
import numpy as np
import pkg_resources
import pytest
from generic_helpers import generate_fixtures
from sklearn.datasets import load_diabetes, load_iris
Expand All @@ -12,44 +13,71 @@
from synthcity.plugins.generic.plugin_goggle import plugin
from synthcity.utils.serialization import load, save

is_missing_goggle_deps = plugin is None

plugin_name = "goggle"
plugin_args = {
"n_iter": 10,
"device": "cpu",
}

if not is_missing_goggle_deps:
goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"}
installed = {pkg.key for pkg in pkg_resources.working_set}
is_missing_goggle_deps = len(goggle_dependencies - installed) > 0


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_sanity(test_plugin: Plugin) -> None:
assert test_plugin is not None


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_name(test_plugin: Plugin) -> None:
assert test_plugin.name() == plugin_name


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_type(test_plugin: Plugin) -> None:
assert test_plugin.type() == "generic"


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 9


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
"test_plugin",
generate_fixtures(plugin_name, plugin, plugin_args),
)
def test_plugin_fit(test_plugin: Plugin) -> None:
Xraw, y = load_diabetes(return_X_y=True, as_frame=True)
Xraw["target"] = y
test_plugin.fit(GenericDataLoader(Xraw))


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
"test_plugin",
generate_fixtures(plugin_name, plugin, plugin_args),
)
@pytest.mark.parametrize("serialize", [True, False])
def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
Expand Down Expand Up @@ -78,8 +106,10 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
assert (X_gen1.numpy() != X_gen3.numpy()).any()


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
"test_plugin",
generate_fixtures(plugin_name, plugin, plugin_args),
)
def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None:
X, y = load_iris(as_frame=True, return_X_y=True)
Expand All @@ -105,12 +135,15 @@ def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None:
assert list(X_gen.columns) == list(X.columns)


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
def test_sample_hyperparams() -> None:
assert plugin is not None
for i in range(100):
args = plugin.sample_hyperparameters()
assert plugin(**args) is not None


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.slow
@pytest.mark.parametrize(
"compress_dataset,decoder_arch",
Expand All @@ -130,6 +163,7 @@ def test_eval_performance_goggle(compress_dataset: bool, decoder_arch: str) -> N
Xraw["target"] = y
X = GenericDataLoader(Xraw)

assert plugin is not None
for retry in range(2):
test_plugin = plugin(
n_iter=5000,
Expand Down