Full tests #294

Closed · wants to merge 22 commits
c8b8c13
update workflow
robsdavis Sep 17, 2024
8dda640
temporarily suppress short tests
robsdavis Sep 17, 2024
cb59180
temporarily suppress short tests
robsdavis Sep 17, 2024
d536d2e
migrate pydantic (#295)
robsdavis Oct 1, 2024
6043e1b
clean up
robsdavis Oct 1, 2024
208c6be
Swap keep alive to approach (#309)
robsdavis Jan 6, 2025
7b948c7
Keep alive (#314)
robsdavis Jan 6, 2025
d6da33a
use all available datasets for computing the encoding of the categori…
Davee02 Jan 7, 2025
21f8e30
add logging on failed tabular goggle import (#316)
robsdavis Jan 7, 2025
3836949
Automated commit by Keepalive Workflow to keep the repository active
Jan 7, 2025
0809db4
update workflow
robsdavis Sep 17, 2024
c95998a
temporarily suppress short tests
robsdavis Sep 17, 2024
3191d99
temporarily suppress short tests
robsdavis Sep 17, 2024
c0a1a0c
Merge branch 'full-tests' of https://github.com/vanderschaarlab/synth…
robsdavis Jan 8, 2025
4f02baa
comment flakey test
robsdavis Jan 8, 2025
5877192
stabilise tab_ddpm internal functions (#317)
robsdavis Jan 8, 2025
555f8e4
update workflow
robsdavis Sep 17, 2024
d055b42
temporarily suppress short tests
robsdavis Sep 17, 2024
d49ce89
temporarily suppress short tests
robsdavis Sep 17, 2024
8772470
Automated commit by Keepalive Workflow to keep the repository active
Jan 7, 2025
7770046
comment flakey test
robsdavis Jan 8, 2025
673c851
Merge branch 'full-tests' of https://github.com/vanderschaarlab/synth…
robsdavis Jan 8, 2025
1 change: 0 additions & 1 deletion .github/workflows/test_all_tutorials.yml
@@ -16,7 +16,6 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
8 changes: 4 additions & 4 deletions .github/workflows/test_full.yml
@@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
ref: main
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
@@ -30,18 +30,18 @@ jobs:
run: |
python -m pip install -U pip
pip install -r prereq.txt
- name: Limit OpenMP threads
run: |
echo "OMP_NUM_THREADS=2" >> $GITHUB_ENV
- name: Test Core - slow part one
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "slow_1"
- name: Test Core - slow part two
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "slow_2"
- name: Test Core - fast
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "not slow"
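The updated `test_full.yml` caps OpenMP at two threads and splits the suite across `slow_1`/`slow_2` pytest markers. The same idea can be sketched in Python; the helper name below is illustrative, not part of the repo:

```python
import os

# The workflow writes OMP_NUM_THREADS=2 into GITHUB_ENV. Locally, the cap
# must be set before importing numpy/torch, because OpenMP thread pools are
# sized when the library is loaded.
os.environ.setdefault("OMP_NUM_THREADS", "2")

def pytest_args(marker: str) -> list:
    # Mirrors the workflow's invocations, e.g. pytest -vvvs --durations=50 -m "slow_1";
    # this only builds the argument list rather than running pytest.
    return ["pytest", "-vvvs", "--durations=50", "-m", marker]

slow_one = pytest_args("slow_1")
```

Splitting the slow markers into separate steps keeps each job under the runner's wall-clock budget while the fast tests run with `-m "not slow"`.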
4 changes: 2 additions & 2 deletions .github/workflows/test_pr.yml
@@ -3,8 +3,8 @@ name: Tests Fast Python
on:
push:
branches: [main, release]
pull_request:
types: [opened, synchronize, reopened]
# pull_request:
# types: [opened, synchronize, reopened]
workflow_dispatch:

jobs:
8 changes: 4 additions & 4 deletions .github/workflows/test_tutorials.yml
@@ -3,8 +3,8 @@ name: PR Tutorials
on:
push:
branches: [main, release]
pull_request:
types: [opened, synchronize, reopened]
# pull_request:
# types: [opened, synchronize, reopened]
schedule:
- cron: "2 3 * * 4"
workflow_dispatch:
@@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
ref: main
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
@@ -40,4 +40,4 @@ jobs:
python -m pip install ipykernel
python -m ipykernel install --user
- name: Run the tutorials
run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests
run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests --timeout 3600
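The tutorial runner now gets an explicit `--timeout 3600`. The same guard can be sketched for any long-running child process; `run_with_timeout` is an illustrative helper, not part of `nb_eval.py`:

```python
import subprocess
import sys

def run_with_timeout(args: list, timeout_s: int = 3600) -> int:
    """Run a child process, returning its exit code, or -1 on timeout."""
    try:
        return subprocess.run(args, timeout=timeout_s, check=False).returncode
    except subprocess.TimeoutExpired:
        return -1

# A trivially fast child process completes well inside the budget.
rc = run_with_timeout([sys.executable, "-c", "print('ok')"], timeout_s=60)
```

A per-notebook timeout fails one stuck tutorial instead of letting the whole CI job hit its 1000-minute step limit.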
2 changes: 1 addition & 1 deletion setup.cfg
@@ -47,7 +47,7 @@ install_requires =
tenacity
tqdm
loguru
pydantic<2.0
pydantic>=2.0
cloudpickle
scipy
xgboost<3.0.0
4 changes: 4 additions & 0 deletions src/synthcity/benchmark/__init__.py
@@ -57,6 +57,7 @@ def evaluate(
strict_augmentation: bool = False,
ad_hoc_augment_vals: Optional[Dict] = None,
use_metric_cache: bool = True,
n_eval_folds: int = 5,
**generate_kwargs: Any,
) -> pd.DataFrame:
"""Benchmark the performance of several algorithms.
@@ -102,6 +103,8 @@
A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
use_metric_cache: bool
If the current metric has been previously run and is cached, it will be reused for the experiments. Defaults to True.
n_eval_folds: int
the KFolds used by MetricEvaluators in the benchmarks. Defaults to 5.
plugin_kwargs:
Optional kwargs for each algorithm. Example {"adsgan": {"n_iter": 10}},
"""
@@ -295,6 +298,7 @@
task_type=task_type,
workspace=workspace,
use_cache=use_metric_cache,
n_folds=n_eval_folds,
)

mean_score = evaluation["mean"].to_dict()
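The new `n_eval_folds` argument is forwarded to the metric evaluators as `n_folds`, which drives a k-fold split of the evaluation data. The partitioning can be sketched from scratch; the helper below is a from-scratch illustration, not synthcity's implementation:

```python
def kfold_indices(n_samples: int, n_folds: int) -> list:
    """Partition range(n_samples) into n_folds (train, test) index pairs."""
    # Distribute any remainder across the first folds, as library KFolds do.
    fold_sizes = [
        n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
        for i in range(n_folds)
    ]
    folds, start = [], 0
    for size in fold_sizes:
        test_idx = list(range(start, start + size))
        held_out = set(test_idx)
        train_idx = [i for i in range(n_samples) if i not in held_out]
        folds.append((train_idx, test_idx))
        start += size
    return folds

splits = kfold_indices(10, 5)
```

Raising `n_eval_folds` tightens the variance estimate on each metric at the cost of proportionally more evaluation passes.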
6 changes: 3 additions & 3 deletions src/synthcity/metrics/_utils.py
@@ -332,7 +332,7 @@ def f() -> None:
"epoch": epoch,
},
workspace / "DomiasMIA_bnaf_checkpoint.pt",
)
) # nosec B614

return f

@@ -348,7 +348,7 @@ def f() -> None:

log.info("Loading model..")
if (workspace / "checkpoint.pt").exists():
checkpoint = torch.load(workspace / "checkpoint.pt")
checkpoint = torch.load(workspace / "checkpoint.pt") # nosec B614
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

@@ -453,7 +453,7 @@ def train(
"epoch": epoch,
},
workspace / "checkpoint.pt",
)
) # nosec B614
log.debug(
f"""
###### Stop training after {epoch + 1} epochs!
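The `# nosec B614` comments acknowledge that `torch.save`/`torch.load` are pickle-based, which bandit flags; here the workspace checkpoints are treated as trusted local files. A torch-free sketch of the same save/resume pattern, using JSON purely for illustration (real model and optimizer state requires torch serialization):

```python
import json
import tempfile
from pathlib import Path

def save_checkpoint(workspace: Path, state: dict) -> None:
    # Stand-in for torch.save(state, workspace / "checkpoint.pt").
    (workspace / "checkpoint.json").write_text(json.dumps(state))

def load_checkpoint(workspace: Path) -> dict:
    # Stand-in for the guarded torch.load: fall back to a fresh run
    # when no checkpoint exists yet.
    path = workspace / "checkpoint.json"
    return json.loads(path.read_text()) if path.exists() else {"epoch": 0}

ws = Path(tempfile.mkdtemp())
save_checkpoint(ws, {"epoch": 7, "model": "state-dict-placeholder"})
resumed = load_checkpoint(ws)
```

Suppressing B614 rather than switching formats is reasonable here because the checkpoints are written and read by the same process in a private workspace directory.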
18 changes: 13 additions & 5 deletions src/synthcity/metrics/eval.py
@@ -119,6 +119,7 @@ def evaluate(
random_state: int = 0,
workspace: Path = Path("workspace"),
use_cache: bool = True,
n_folds: int = 5,
) -> pd.DataFrame:
"""Core evaluation logic for the metrics

@@ -202,12 +203,16 @@

"""
We need to encode the categorical data in the real and synthetic data.
To ensure each category in the two datasets are mapped to the same one hot vector, we merge X_syn into X_gt for computing the encoder.
TODO: Check whether the optional datasets also need to be taking into account when getting the encoder.
To ensure each category in the two datasets is mapped to the same one-hot vector, we merge all available datasets for computing the encoder.
"""
X_gt_df = X_gt.dataframe()
X_syn_df = X_syn.dataframe()
X_enc = create_from_info(pd.concat([X_gt_df, X_syn_df]), X_gt.info())
all_df = pd.concat([X_gt.dataframe(), X_syn.dataframe()])
if X_train:
all_df = pd.concat([all_df, X_train.dataframe()])
if X_ref_syn:
all_df = pd.concat([all_df, X_ref_syn.dataframe()])
if X_augmented:
all_df = pd.concat([all_df, X_augmented.dataframe()])
X_enc = create_from_info(all_df, X_gt.info())
_, encoders = X_enc.encode()

# now we encode the data
@@ -238,6 +243,7 @@
random_state=random_state,
workspace=workspace,
use_cache=use_cache,
n_folds=n_folds,
),
X_gt,
X_augmented,
@@ -251,6 +257,7 @@
random_state=random_state,
workspace=workspace,
use_cache=use_cache,
n_folds=n_folds,
),
X_gt,
X_syn,
@@ -267,6 +274,7 @@
random_state=random_state,
workspace=workspace,
use_cache=use_cache,
n_folds=n_folds,
),
X_gt.sample(eval_cnt),
X_syn.sample(eval_cnt),
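Fitting the encoder on the concatenation of every available dataset guarantees that all splits share one category vocabulary, so their one-hot columns line up. A pandas illustration using `get_dummies` as a stand-in for the real encoder:

```python
import pandas as pd

# A category present only in the synthetic data ("green") would be invisible
# to an encoder fitted on the real data alone.
real = pd.DataFrame({"color": ["red", "blue"]})
synth = pd.DataFrame({"color": ["green", "red"]})

# Fitting on the union yields one consistent column layout for both frames...
union_cols = pd.get_dummies(pd.concat([real, synth])).columns.tolist()

# ...whereas encoding the real frame on its own drops "color_green".
real_only_cols = pd.get_dummies(real).columns.tolist()
```

This is why the PR folds `X_train`, `X_ref_syn`, and `X_augmented` into `all_df` when present, instead of fitting on `X_gt` and `X_syn` only.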
10 changes: 6 additions & 4 deletions src/synthcity/plugins/core/constraints.py
@@ -4,11 +4,13 @@
# third party
import numpy as np
import pandas as pd
from pydantic import BaseModel, validate_arguments, validator
from pydantic import BaseModel, field_validator, validate_arguments

# synthcity absolute
import synthcity.logger as log

Rule = Tuple[str, str, Any] # Define a type alias for clarity


class Constraints(BaseModel):
"""
@@ -41,10 +43,10 @@ class Constraints(BaseModel):
and thresh is the threshold or data type.
"""

rules: list = []
rules: list[Rule] = []

@validator("rules")
def _validate_rules(cls: Any, rules: List, values: dict, **kwargs: Any) -> List:
@field_validator("rules", mode="before")
def _validate_rules(cls: Any, rules: List) -> List:
supported_ops: list = [
"<",
">=",
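Under pydantic v2, `@validator` becomes `@field_validator` and the extra `values`/`**kwargs` parameters are dropped. A plain-Python sketch of the check the validator performs (the diff truncates the operator list, so the one below is illustrative, not the repo's full set):

```python
# Each rule is a (feature, op, threshold) tuple, matching the Rule alias
# added in the PR. Only "<" and ">=" are visible in the diff; the rest of
# this list is assumed for illustration.
SUPPORTED_OPS = ["<", ">="]

def validate_rules(rules: list) -> list:
    """Reject any rule whose operator is not supported, like _validate_rules."""
    for feature, op, thresh in rules:
        if op not in SUPPORTED_OPS:
            raise ValueError(f"unsupported constraint operation: {op}")
    return rules

ok = validate_rules([("age", ">=", 18)])
```

Running with `mode="before"` means the check sees the raw input list before pydantic coerces it into the annotated `list[Rule]` type.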