Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added function to make kwargs json serializable in Benchmarks.evaluate #157

Merged
merged 2 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/synthcity/benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# synthcity absolute
import synthcity.logger as log
from synthcity.benchmark.utils import augment_data
from synthcity.benchmark.utils import augment_data, get_json_serializable_kwargs
from synthcity.metrics import Metrics
from synthcity.metrics.scores import ScoreEvaluator
from synthcity.plugins import Plugins
Expand Down Expand Up @@ -129,7 +129,10 @@ def evaluate(

kwargs_hash = ""
if len(kwargs) > 0:
kwargs_hash_raw = json.dumps(kwargs, sort_keys=True).encode()
serializable_kwargs = get_json_serializable_kwargs(kwargs)
kwargs_hash_raw = json.dumps(
serializable_kwargs, sort_keys=True
).encode()
hash_object = hashlib.sha256(kwargs_hash_raw)
kwargs_hash = hash_object.hexdigest()

Expand Down
15 changes: 14 additions & 1 deletion src/synthcity/benchmark/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
import math
from copy import copy
from copy import copy, deepcopy
from pathlib import Path
from typing import Any, Dict, Optional

# third party
Expand All @@ -14,6 +15,18 @@
from synthcity.plugins.core.dataloader import DataLoader


def _stringify_paths(value: Any) -> Any:
    """Return ``value`` with every ``pathlib.Path`` replaced by its ``str`` form.

    Dicts, lists and tuples are rebuilt recursively so that paths nested at
    any depth are converted; every other value is returned unchanged.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {key: _stringify_paths(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_stringify_paths(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_stringify_paths(item) for item in value)
    return value


def get_json_serializable_kwargs(kwargs: Dict) -> Dict:
    """Return a copy of ``kwargs`` that can be passed to ``json.dumps``.

    Currently the only conversion performed is ``pathlib.Path`` -> ``str``.
    The conversion is applied to top-level values and to values nested inside
    dicts, lists and tuples.

    Args:
        kwargs: The keyword arguments to make serializable. Not modified.

    Returns:
        A new dict, independent of ``kwargs``, in which every ``Path`` found
        has been replaced by its string representation.
    """
    # Deep-copy first so the caller's objects are never shared with, or
    # mutated through, the returned structure.
    copied = deepcopy(kwargs)
    return {key: _stringify_paths(item) for key, item in copied.items()}


def calculate_fair_aug_sample_size(
X_train: pd.DataFrame,
fairness_column: Optional[str], # a categorical column of K levels
Expand Down
9 changes: 8 additions & 1 deletion tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# synthcity absolute
from synthcity.benchmark import Benchmarks
from synthcity.benchmark.utils import get_json_serializable_kwargs
from synthcity.plugins.core.dataloader import (
GenericDataLoader,
SurvivalAnalysisDataLoader,
Expand Down Expand Up @@ -219,8 +220,14 @@ def test_benchmark_workspace_cache() -> None:

testcase = "test1"
plugin = "uniform_sampler"
kwargs = {"workspace": Path("workspace_test")}

kwargs_hash = ""
if len(kwargs) > 0:
serializable_kwargs = get_json_serializable_kwargs(kwargs)
kwargs_hash_raw = json.dumps(serializable_kwargs, sort_keys=True).encode()
hash_object = hashlib.sha256(kwargs_hash_raw)
kwargs_hash = hash_object.hexdigest()

augmentation_arguments = {
"augmentation_rule": "equal",
Expand All @@ -238,7 +245,7 @@ def test_benchmark_workspace_cache() -> None:

Benchmarks.evaluate(
[
(testcase, plugin, {}),
(testcase, plugin, kwargs),
],
X,
task_type="survival_analysis",
Expand Down