From 9df85147a7d776a2d5a9435c31c70d5da6083ddd Mon Sep 17 00:00:00 2001 From: robsdavis Date: Mon, 3 Apr 2023 16:39:48 +0100 Subject: [PATCH 01/13] Skip goggle tests if dependencies not installed and pin torch<2.0 --- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- setup.cfg | 2 +- tests/plugins/generic/generic_helpers.py | 9 +++- tests/plugins/generic/test_goggle.py | 62 +++++++++++++++++++++--- 5 files changed, 64 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 35a8b2d0..9f958b10 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -57,6 +57,6 @@ jobs: pip install -r prereq.txt pip install --upgrade pip - pip install .[all] + pip install .[testing] - name: Test with pytest run: pytest -vvvs -m "not slow" --durations=50 diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 212e79e2..4be2288b 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -35,7 +35,7 @@ jobs: pip install -r prereq.txt pip install --upgrade pip - pip install .[all] + pip install .[testing] python -m pip install ipykernel python -m ipykernel install --user diff --git a/setup.cfg b/setup.cfg index de0ac067..33f54e16 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ install_requires = scikit-learn>=1.0 nflows>=0.14 pandas>=1.3 - torch>=1.10.0 + torch<2.0.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 diff --git a/tests/plugins/generic/generic_helpers.py b/tests/plugins/generic/generic_helpers.py index 1b9453e0..bc8719a7 100644 --- a/tests/plugins/generic/generic_helpers.py +++ b/tests/plugins/generic/generic_helpers.py @@ -10,7 +10,9 @@ from synthcity.utils.serialization import load, save -def generate_fixtures(name: str, plugin: Type, plugin_args: Dict = {}) -> List: +def generate_fixtures( + name: str, plugin: Type, plugin_args: Dict = {}, use_dummy_fixtures: bool = False +) -> List: def from_api() -> Plugin: return Plugins().get(name, **plugin_args) @@ -21,7 +23,10 @@ def from_serde() -> Plugin: buff = save(plugin(**plugin_args)) return load(buff) - return [from_api(), from_module(), from_serde()] + if use_dummy_fixtures: + return [None, None, None] + else: + return [from_api(), from_module(), from_serde()] def get_airfoil_dataset() -> pd.DataFrame: diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 0fd89099..1e91efc1 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -1,5 +1,7 @@ +# Standard library imports # third party import numpy as np +import pkg_resources import pytest from generic_helpers import generate_fixtures from sklearn.datasets import load_diabetes, load_iris @@ -9,38 +11,72 @@ from synthcity.plugins import Plugin from synthcity.plugins.core.constraints import Constraints from synthcity.plugins.core.dataloader import GenericDataLoader -from synthcity.plugins.generic.plugin_goggle import plugin from synthcity.utils.serialization import load, save +is_missing_goggle_deps = False +try: + # synthcity absolute + from synthcity.plugins.generic.plugin_goggle import plugin +except ImportError: + plugin = None + is_missing_goggle_deps = True + plugin_name = "goggle" plugin_args = { "n_iter": 10, "device": "cpu", } +if not is_missing_goggle_deps: + goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"} + installed = {pkg.key for pkg in pkg_resources.working_set} + is_missing_goggle_deps = len(goggle_dependencies - installed) > 0 + +print(is_missing_goggle_deps) + -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), +) def test_plugin_sanity(test_plugin: Plugin) -> None: assert test_plugin is not None -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), +) def test_plugin_name(test_plugin: Plugin) -> None: assert test_plugin.name() == plugin_name -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), +) def test_plugin_type(test_plugin: Plugin) -> None: assert test_plugin.type() == "generic" -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), +) def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 9 +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", + generate_fixtures( + plugin_name, plugin, plugin_args, use_dummy_fixtures=is_missing_goggle_deps + ), ) def test_plugin_fit(test_plugin: Plugin) -> None: Xraw, y = load_diabetes(return_X_y=True, as_frame=True) @@ -48,8 +84,12 @@ def test_plugin_fit(test_plugin: Plugin) -> None: test_plugin.fit(GenericDataLoader(Xraw)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", + generate_fixtures( + plugin_name, plugin, plugin_args, use_dummy_fixtures=is_missing_goggle_deps + ), ) @pytest.mark.parametrize("serialize", [True, False]) def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: @@ -78,8 +118,12 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: assert (X_gen1.numpy() != X_gen3.numpy()).any() +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", + generate_fixtures( + plugin_name, plugin, plugin_args, use_dummy_fixtures=is_missing_goggle_deps + ), ) def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None: X, y = load_iris(as_frame=True, return_X_y=True) @@ -105,12 +149,14 @@ def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None: assert list(X_gen.columns) == list(X.columns) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") def test_sample_hyperparams() -> None: for i in range(100): args = plugin.sample_hyperparameters() assert plugin(**args) is not None +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.slow @pytest.mark.parametrize( "compress_dataset,decoder_arch", From 641c527580f7630694ef7529ba3665f13487c188 Mon Sep 17 00:00:00 2001 From: robsdavis Date: Mon, 3 Apr 2023 16:52:01 +0100 Subject: [PATCH 02/13] Pass pre-commit --- tests/plugins/generic/test_goggle.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 1e91efc1..f4dd0c7b 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -1,4 +1,3 @@ -# Standard library imports # third party import numpy as np import pkg_resources @@ -18,7 +17,14 @@ # synthcity absolute from synthcity.plugins.generic.plugin_goggle import plugin except ImportError: - plugin = None + """ + Import dummy_sampler, but don't use it. + A valid plugin is required for generate_fixtures, but all tests should be skipped, if + the goggle dependencies are missing. + """ + # synthcity absolute + from synthcity.plugins.generic.plugin_dummy_sampler import plugin + is_missing_goggle_deps = True plugin_name = "goggle" From 3e7274aa4a1809d2fbd7f889f36c2b6387b93999 Mon Sep 17 00:00:00 2001 From: robsdavis Date: Mon, 3 Apr 2023 17:16:43 +0100 Subject: [PATCH 03/13] install goggle dependencies in workflows --- .github/workflows/test_pr.yml | 2 +- .github/workflows/test_tutorials.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 9f958b10..35a8b2d0 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -57,6 +57,6 @@ jobs: pip install -r prereq.txt pip install --upgrade pip - pip install .[testing] + pip install .[all] - name: Test with pytest run: pytest -vvvs -m "not slow" --durations=50 diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 4be2288b..212e79e2 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -35,7 +35,7 @@ jobs: pip install -r prereq.txt pip install --upgrade pip - pip install .[testing] + pip install .[all] python -m pip install ipykernel python -m ipykernel install --user From 6a957bfcc116539c095892a53b36fe3512ec68b8 Mon Sep 17 00:00:00 2001 From: robsdavis Date: Mon, 3 Apr 2023 17:47:00 +0100 Subject: [PATCH 04/13] match pytorch version in prereq to version in setup.cfg --- prereq.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prereq.txt b/prereq.txt index dbdb5914..5ead92b8 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,3 @@ numpy -torch +torch<2.0.0 tsai From 9b8e99659a77a9940687e095d270beec41b6c406 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 06:44:01 +0300 Subject: [PATCH 05/13] more depends issues --- setup.cfg | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 33f54e16..6e75f886 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=1.3 - torch<2.0.0 + pandas>=1.3,<2.0 + torch>=1.10 numpy>=1.20 lifelines>=0.27 opacus>=1.3 @@ -92,6 +92,7 @@ testing = click goggle = + torch<2.0 dgl torch_geometric torch_sparse From 5f4ba1917a9d48c1d414fbea1d108ff5dac3a155 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 07:04:26 +0300 Subject: [PATCH 06/13] cleanup prereq --- prereq.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prereq.txt b/prereq.txt index 5ead92b8..dbdb5914 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,3 @@ numpy -torch<2.0.0 +torch tsai From 2e69a36f081f79e46b4c0816a4febd9aaafdd0e0 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 07:12:31 +0300 Subject: [PATCH 07/13] reorganize workflows --- .github/workflows/test_full.yml | 11 ++++++++--- .github/workflows/test_pr.yml | 12 ++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index 07727c4a..7c3003f9 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -30,6 +30,11 @@ jobs: pip install -r prereq.txt pip install --upgrade pip - pip install .[all] - - name: Test with pytest - run: pytest -vvvs --durations=50 + - name: Test Core + run: | + pip install .[testing] + pytest -vvvs --durations=50 + - name: Test GOGGLE + run: | + pip install .[testing,goggle] + pytest -vvvs -k goggle --durations=50 diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 35a8b2d0..e60afb9d 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -55,8 +55,12 @@ jobs: - name: Install dependencies run: | pip install -r prereq.txt - pip install --upgrade pip - pip install .[all] - - name: Test with pytest - run: pytest -vvvs -m "not slow" --durations=50 + - name: Test Core + run: | + pip install .[testing] + pytest -vvvs -m "not slow" --durations=50 + - name: Test GOGGLE + run: | + pip install .[testing,goggle] + pytest -vvvs -k goggle From f9307c952c4ca791c750f2531a63652ea472e895 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 07:30:06 +0300 Subject: [PATCH 08/13] handle modules that cannot be loaded --- src/synthcity/plugins/core/plugin.py | 24 +++++++++++++++---- .../plugins/generic/plugin_goggle.py | 14 +++++++++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py index 33529764..d4049bed 100644 --- a/src/synthcity/plugins/core/plugin.py +++ b/src/synthcity/plugins/core/plugin.py @@ -548,13 +548,16 @@ def __init__(self, plugins: list, expected_type: Type, categories: list) -> None self._available_plugins = {} for plugin in plugins: stem = Path(plugin).stem.split("plugin_")[-1] + cls = self._load_single_plugin_impl(plugin) + if cls is None: + continue self._available_plugins[stem] = plugin self._expected_type = expected_type self._categories = categories @validate_arguments - def _load_single_plugin(self, plugin_name: str) -> None: - """Helper for loading a single plugin""" + def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]: + """Helper for loading a single plugin implementation""" plugin = Path(plugin_name) name = plugin.stem ptype = plugin.parent.name @@ -579,6 +582,10 @@ def _load_single_plugin(self, plugin_name: str) -> None: spec.loader.exec_module(mod) cls = mod.plugin + if cls is None: + log.critical(f"module disabled: {plugin_name}") + return None + failed = False break except BaseException as e: @@ -587,10 +594,19 @@ def _load_single_plugin(self, plugin_name: str) -> None: if failed: log.critical(f"module {name} load failed") - return + return None + + return cls + + @validate_arguments + def _load_single_plugin(self, plugin_name: str) -> bool: + """Helper for loading a single plugin""" + cls = self._load_single_plugin_impl(plugin_name) + if cls is None: + return False - log.debug(f"Loaded plugin {cls.type()} - {cls.name()}") self.add(cls.name(), cls) + return True def list(self) -> List[str]: """Get all the available plugins.""" diff --git a/src/synthcity/plugins/generic/plugin_goggle.py b/src/synthcity/plugins/generic/plugin_goggle.py index e829cd3e..3e982f06 100644 --- a/src/synthcity/plugins/generic/plugin_goggle.py +++ b/src/synthcity/plugins/generic/plugin_goggle.py @@ -23,11 +23,18 @@ FloatDistribution, IntegerDistribution, ) -from synthcity.plugins.core.models.tabular_goggle import TabularGoggle from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema from synthcity.utils.constants import DEVICE +try: + # synthcity absolute + from synthcity.plugins.core.models.tabular_goggle import TabularGoggle + + module_disabled = False +except ImportError: + module_disabled = True + class GOGGLEPlugin(Plugin): """ @@ -248,4 +255,7 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFra return self._safe_generate(self.model.generate, count, syn_schema) -plugin = GOGGLEPlugin +if module_disabled: + plugin = None +else: + plugin = GOGGLEPlugin From 2535817065c822f777eb692cd5c69a09bd080099 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 07:39:54 +0300 Subject: [PATCH 09/13] fix linting errors --- tests/plugins/generic/generic_helpers.py | 16 +++++----- tests/plugins/generic/test_goggle.py | 38 ++++++------------------ 2 files changed, 17 insertions(+), 37 deletions(-) diff --git a/tests/plugins/generic/generic_helpers.py b/tests/plugins/generic/generic_helpers.py index bc8719a7..9abd23e9 100644 --- a/tests/plugins/generic/generic_helpers.py +++ b/tests/plugins/generic/generic_helpers.py @@ -1,5 +1,5 @@ # stdlib -from typing import Dict, List, Type +from typing import Dict, List, Optional, Type # third party import pandas as pd @@ -11,22 +11,22 @@ def generate_fixtures( - name: str, plugin: Type, plugin_args: Dict = {}, use_dummy_fixtures: bool = False + name: str, plugin: Optional[Type], plugin_args: Dict = {} ) -> List: + if plugin is None: + return [] + def from_api() -> Plugin: return Plugins().get(name, **plugin_args) def from_module() -> Plugin: - return plugin(**plugin_args) + return plugin(**plugin_args) # type: ignore def from_serde() -> Plugin: - buff = save(plugin(**plugin_args)) + buff = save(plugin(**plugin_args)) # type: ignore return load(buff) - if use_dummy_fixtures: - return [None, None, None] - else: - return [from_api(), from_module(), from_serde()] + return [from_api(), from_module(), from_serde()] def get_airfoil_dataset() -> pd.DataFrame: diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index f4dd0c7b..da391911 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -10,22 +10,10 @@ from synthcity.plugins import Plugin from synthcity.plugins.core.constraints import Constraints from synthcity.plugins.core.dataloader import GenericDataLoader +from synthcity.plugins.generic.plugin_goggle import plugin from synthcity.utils.serialization import load, save -is_missing_goggle_deps = False -try: - # synthcity absolute - from synthcity.plugins.generic.plugin_goggle import plugin -except ImportError: - """ - Import dummy_sampler, but don't use it. - A valid plugin is required for generate_fixtures, but all tests should be skipped, if - the goggle dependencies are missing. - """ - # synthcity absolute - from synthcity.plugins.generic.plugin_dummy_sampler import plugin - - is_missing_goggle_deps = True +is_missing_goggle_deps = plugin is None plugin_name = "goggle" plugin_args = { @@ -38,13 +26,11 @@ installed = {pkg.key for pkg in pkg_resources.working_set} is_missing_goggle_deps = len(goggle_dependencies - installed) > 0 -print(is_missing_goggle_deps) - @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), + generate_fixtures(plugin_name, plugin), ) def test_plugin_sanity(test_plugin: Plugin) -> None: assert test_plugin is not None @@ -53,7 +39,7 @@ def test_plugin_sanity(test_plugin: Plugin) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), + generate_fixtures(plugin_name, plugin), ) def test_plugin_name(test_plugin: Plugin) -> None: assert test_plugin.name() == plugin_name @@ -62,7 +48,7 @@ def test_plugin_name(test_plugin: Plugin) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), + generate_fixtures(plugin_name, plugin), ) def test_plugin_type(test_plugin: Plugin) -> None: assert test_plugin.type() == "generic" @@ -71,7 +57,7 @@ def test_plugin_type(test_plugin: Plugin) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures(plugin_name, plugin, use_dummy_fixtures=is_missing_goggle_deps), + generate_fixtures(plugin_name, plugin), ) def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 9 @@ -80,9 +66,7 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures( - plugin_name, plugin, plugin_args, use_dummy_fixtures=is_missing_goggle_deps - ), + generate_fixtures(plugin_name, plugin, plugin_args), ) def test_plugin_fit(test_plugin: Plugin) -> None: Xraw, y = load_diabetes(return_X_y=True, as_frame=True) @@ -93,9 +77,7 @@ def test_plugin_fit(test_plugin: Plugin) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures( - plugin_name, plugin, plugin_args, use_dummy_fixtures=is_missing_goggle_deps - ), + generate_fixtures(plugin_name, plugin, plugin_args), ) @pytest.mark.parametrize("serialize", [True, False]) def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: @@ -127,9 +109,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( "test_plugin", - generate_fixtures( - plugin_name, plugin, plugin_args, use_dummy_fixtures=is_missing_goggle_deps - ), + generate_fixtures(plugin_name, plugin, plugin_args), ) def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None: X, y = load_iris(as_frame=True, return_X_y=True) From 4eb1bdafe1844336ca7b970f9b33f039d645cedb Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 07:43:54 +0300 Subject: [PATCH 10/13] more linting --- tests/plugins/generic/test_goggle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index da391911..9b194ae0 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -137,6 +137,7 @@ def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None: @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") def test_sample_hyperparams() -> None: + assert plugin is not None for i in range(100): args = plugin.sample_hyperparameters() assert plugin(**args) is not None @@ -162,6 +163,7 @@ def test_eval_performance_goggle(compress_dataset: bool, decoder_arch: str) -> N Xraw["target"] = y X = GenericDataLoader(Xraw) + assert plugin is not None for retry in range(2): test_plugin = plugin( n_iter=5000, From d5b81bb51e4f9011f3b3deb23b5477f60ed02d4d Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 07:58:10 +0300 Subject: [PATCH 11/13] rename PR tests --- .github/workflows/test_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index e60afb9d..077461c1 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -1,4 +1,4 @@ -name: Tests Python +name: Tests Fast Python on: push: From 77b7a46d94d9a91418249a3519c998e8bf23fabf Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 08:30:02 +0300 Subject: [PATCH 12/13] cleanup workflows --- .github/workflows/test_full.yml | 1 - .github/workflows/test_pr.yml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index 7c3003f9..e3265692 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -28,7 +28,6 @@ jobs: - name: Install dependencies run: | pip install -r prereq.txt - pip install --upgrade pip - name: Test Core run: | diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 077461c1..333e4b4f 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -63,4 +63,4 @@ jobs: - name: Test GOGGLE run: | pip install .[testing,goggle] - pytest -vvvs -k goggle + pytest -vvvs -m "not slow" -k goggle From 098baa7c1f0f0b2b95dac65453651bfcd7610d37 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Tue, 4 Apr 2023 08:39:46 +0300 Subject: [PATCH 13/13] dependency mess --- .github/workflows/test_pr.yml | 4 ++-- prereq.txt | 2 +- setup.cfg | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 333e4b4f..2c907433 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -59,8 +59,8 @@ jobs: - name: Test Core run: | pip install .[testing] - pytest -vvvs -m "not slow" --durations=50 + pytest -vvvsx -m "not slow" --durations=50 - name: Test GOGGLE run: | pip install .[testing,goggle] - pytest -vvvs -m "not slow" -k goggle + pytest -vvvsx -m "not slow" -k goggle diff --git a/prereq.txt b/prereq.txt index dbdb5914..0d7eb1f0 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,3 @@ numpy -torch +torch<2.0 tsai diff --git a/setup.cfg b/setup.cfg index 6e75f886..96a67f09 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ install_requires = scikit-learn>=1.0 nflows>=0.14 pandas>=1.3,<2.0 - torch>=1.10 + torch>=1.10,<2.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 @@ -92,7 +92,6 @@ testing = click goggle = - torch<2.0 dgl torch_geometric torch_sparse