Skip to content

Commit

Permalink
uncomment registry test
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Jan 3, 2024
1 parent f349f62 commit 3c147dc
Showing 1 changed file with 57 additions and 58 deletions.
115 changes: 57 additions & 58 deletions tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import platform
from copy import copy
from pathlib import Path
from typing import Any, List

# third party
import pytest
Expand All @@ -13,17 +14,15 @@
# synthcity absolute
from synthcity.benchmark import Benchmarks
from synthcity.benchmark.utils import get_json_serializable_kwargs

# from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import ( # DataLoader,
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (
DataLoader,
GenericDataLoader,
SurvivalAnalysisDataLoader,
)

# from typing import Any, List
# from synthcity.plugins.core.distribution import Distribution
# from synthcity.plugins.core.plugin import Plugin
# from synthcity.plugins.core.schema import Schema
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema


def test_benchmark_sanity() -> None:
Expand Down Expand Up @@ -295,53 +294,53 @@ def test_benchmark_workspace_cache() -> None:
assert augment_generator_file.exists()


# def test_benchmark_added_plugin() -> None:
# X, y = load_iris(return_X_y=True, as_frame=True)
# X["target"] = y

# class DummyCopyDataPlugin(Plugin):
# """Dummy plugin for debugging."""

# def __init__(self, **kwargs: Any) -> None:
# super().__init__(**kwargs)

# @staticmethod
# def name() -> str:
# return "copy_data"

# @staticmethod
# def type() -> str:
# return "debug"

# @staticmethod
# def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
# return []

# def _fit(
# self, X: DataLoader, *args: Any, **kwargs: Any
# ) -> "DummyCopyDataPlugin":
# self.features_count = X.shape[1]
# self.X = X
# return self

# def _generate(
# self, count: int, syn_schema: Schema, **kwargs: Any
# ) -> DataLoader:
# return self.X.sample(count)

# generators = Plugins()
# # Add the new plugin to the collection
# generators.add("copy_data", DummyCopyDataPlugin)

# score = Benchmarks.evaluate(
# [
# ("copy_data", "copy_data", {}),
# ],
# GenericDataLoader(X, target_column="target"),
# metrics={
# "performance": [
# "linear_model",
# ]
# },
# )
# assert "copy_data" in score
def test_benchmark_added_plugin() -> None:
X, y = load_iris(return_X_y=True, as_frame=True)
X["target"] = y

class DummyCopyDataPlugin(Plugin):
"""Dummy plugin for debugging."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@staticmethod
def name() -> str:
return "copy_data"

@staticmethod
def type() -> str:
return "debug"

@staticmethod
def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
return []

def _fit(
self, X: DataLoader, *args: Any, **kwargs: Any
) -> "DummyCopyDataPlugin":
self.features_count = X.shape[1]
self.X = X
return self

def _generate(
self, count: int, syn_schema: Schema, **kwargs: Any
) -> DataLoader:
return self.X.sample(count)

generators = Plugins()
# Add the new plugin to the collection
generators.add("copy_data", DummyCopyDataPlugin)

score = Benchmarks.evaluate(
[
("copy_data", "copy_data", {}),
],
GenericDataLoader(X, target_column="target"),
metrics={
"performance": [
"linear_model",
]
},
)
assert "copy_data" in score

0 comments on commit 3c147dc

Please sign in to comment.