From 29258667f65abf899479deaecc138281e2ec9fc3 Mon Sep 17 00:00:00 2001 From: robsdavis Date: Thu, 23 Mar 2023 12:33:29 +0000 Subject: [PATCH 1/2] Added function to make kwargs json serializable --- src/synthcity/benchmark/__init__.py | 7 +++++-- src/synthcity/benchmark/utils.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/synthcity/benchmark/__init__.py b/src/synthcity/benchmark/__init__.py index e03e04ca..0cb67fa4 100644 --- a/src/synthcity/benchmark/__init__.py +++ b/src/synthcity/benchmark/__init__.py @@ -15,7 +15,7 @@ # synthcity absolute import synthcity.logger as log -from synthcity.benchmark.utils import augment_data +from synthcity.benchmark.utils import augment_data, get_json_serializable_kwargs from synthcity.metrics import Metrics from synthcity.metrics.scores import ScoreEvaluator from synthcity.plugins import Plugins @@ -129,7 +129,10 @@ def evaluate( kwargs_hash = "" if len(kwargs) > 0: - kwargs_hash_raw = json.dumps(kwargs, sort_keys=True).encode() + serializable_kwargs = get_json_serializable_kwargs(kwargs) + kwargs_hash_raw = json.dumps( + serializable_kwargs, sort_keys=True + ).encode() hash_object = hashlib.sha256(kwargs_hash_raw) kwargs_hash = hash_object.hexdigest() diff --git a/src/synthcity/benchmark/utils.py b/src/synthcity/benchmark/utils.py index 732d7e30..01ce5518 100644 --- a/src/synthcity/benchmark/utils.py +++ b/src/synthcity/benchmark/utils.py @@ -1,6 +1,7 @@ # stdlib import math -from copy import copy +from copy import copy, deepcopy +from pathlib import Path from typing import Any, Dict, Optional # third party @@ -14,6 +15,18 @@ from synthcity.plugins.core.dataloader import DataLoader +def get_json_serializable_kwargs(kwargs: Dict) -> Dict: + """ + This function takes the kwargs for Benchmarks.evaluate and makes them serializable with json.dumps. Currently it only converts top-level pathlib.Path values to str. 
+ """ + serializable_kwargs = deepcopy(kwargs) + for k, v in serializable_kwargs.items(): + if isinstance(v, Path): + serializable_kwargs[k] = str(serializable_kwargs[k]) + return serializable_kwargs + + def calculate_fair_aug_sample_size( X_train: pd.DataFrame, fairness_column: Optional[str], # a categorical column of K levels From dfd41c4b86bee6d0b57789001f02d4e69a4914c4 Mon Sep 17 00:00:00 2001 From: robsdavis Date: Thu, 23 Mar 2023 12:48:37 +0000 Subject: [PATCH 2/2] Added kwargs caching testing --- tests/benchmarks/test_benchmarks.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_benchmarks.py index 8c229a24..a08f8926 100644 --- a/tests/benchmarks/test_benchmarks.py +++ b/tests/benchmarks/test_benchmarks.py @@ -12,6 +12,7 @@ # synthcity absolute from synthcity.benchmark import Benchmarks +from synthcity.benchmark.utils import get_json_serializable_kwargs from synthcity.plugins.core.dataloader import ( GenericDataLoader, SurvivalAnalysisDataLoader, @@ -219,8 +220,14 @@ def test_benchmark_workspace_cache() -> None: testcase = "test1" plugin = "uniform_sampler" + kwargs = {"workspace": Path("workspace_test")} kwargs_hash = "" + if len(kwargs) > 0: + serializable_kwargs = get_json_serializable_kwargs(kwargs) + kwargs_hash_raw = json.dumps(serializable_kwargs, sort_keys=True).encode() + hash_object = hashlib.sha256(kwargs_hash_raw) + kwargs_hash = hash_object.hexdigest() augmentation_arguments = { "augmentation_rule": "equal", @@ -238,7 +245,7 @@ def test_benchmark_workspace_cache() -> None: Benchmarks.evaluate( [ - (testcase, plugin, {}), + (testcase, plugin, kwargs), ], X, task_type="survival_analysis",