Augmentation benchmark #150

Merged · 18 commits · Mar 15, 2023
Changes from 16 commits
93 changes: 88 additions & 5 deletions src/synthcity/benchmark/__init__.py
@@ -3,6 +3,7 @@
import json
import platform
import random
from copy import copy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

@@ -14,6 +15,7 @@

# synthcity absolute
import synthcity.logger as log
from synthcity.benchmark.utils import augment_data
from synthcity.metrics import Metrics
from synthcity.metrics.scores import ScoreEvaluator
from synthcity.plugins import Plugins
@@ -48,8 +50,13 @@ def evaluate(
synthetic_constraints: Optional[Constraints] = None,
synthetic_cache: bool = True,
synthetic_reuse_if_exists: bool = True,
augmented_reuse_if_exists: bool = True,
task_type: str = "classification", # classification, regression, survival_analysis, time_series
workspace: Path = Path("workspace"),
augmentation_rule: str = "equal",
strict_augmentation: bool = False,
ad_hoc_augment_vals: Optional[Dict] = None,
use_metric_cache: bool = True,
**generate_kwargs: Any,
) -> pd.DataFrame:
"""Benchmark the performance of several algorithms.
Expand Down Expand Up @@ -81,6 +88,8 @@ def evaluate(
Enable experiment caching
synthetic_reuse_if_exists: bool
If the current synthetic dataset is cached, it will be reused for the experiments.
augmented_reuse_if_exists: bool
If the current augmented dataset is cached, it will be reused for the experiments.
task_type: str
The type of problem. Relevant for evaluating the downstream models with the correct metrics. Valid tasks are: "classification", "regression", "survival_analysis", "time_series", "time_series_survival".
workspace: Path
@@ -115,6 +124,17 @@
hash_object = hashlib.sha256(kwargs_hash_raw)
kwargs_hash = hash_object.hexdigest()

augmentation_arguments = {
"augmentation_rule": augmentation_rule,
"strict_augmentation": strict_augmentation,
"ad_hoc_augment_vals": ad_hoc_augment_vals,
}
augmentation_arguments_hash_raw = json.dumps(
copy(augmentation_arguments), sort_keys=True
).encode()
augmentation_hash_object = hashlib.sha256(augmentation_arguments_hash_raw)
augmentation_hash = augmentation_hash_object.hexdigest()
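# this separate hash keys the augmented-data cache files below, so changing the
# augmentation settings invalidates those caches without invalidating the plain
# synthetic ones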

repeats_list = list(range(repeats))
random.shuffle(repeats_list)

@@ -126,14 +146,22 @@

clear_cache()

cache_file = (
X_syn_cache_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_{kwargs_hash}_{platform.python_version()}_{repeat}.bkp"
)
generator_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_{kwargs_hash}_{platform.python_version()}_generator_{repeat}.bkp"
)
X_augment_cache_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_augmentation_{augmentation_hash}_{kwargs_hash}_{platform.python_version()}_{repeat}.bkp"
)
augment_generator_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_augmentation_{augmentation_hash}_{kwargs_hash}_{platform.python_version()}_generator_{repeat}.bkp"
)

log.info(
f"[testcase] Experiment repeat: {repeat} task type: {task_type} Train df hash = {experiment_name}"
@@ -152,8 +180,8 @@
if synthetic_cache:
save_to_file(generator_file, generator)

if cache_file.exists() and synthetic_reuse_if_exists:
X_syn = load_from_file(cache_file)
if X_syn_cache_file.exists() and synthetic_reuse_if_exists:
X_syn = load_from_file(X_syn_cache_file)
else:
try:
X_syn = generator.generate(
@@ -168,13 +196,68 @@
continue

if synthetic_cache:
save_to_file(cache_file, X_syn)
save_to_file(X_syn_cache_file, X_syn)

# Augmentation
if metrics and any(
"augmentation" in metric
for metric in [x for v in metrics.values() for x in v]
):
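# reached only when at least one requested metric name contains "augmentation"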
if augment_generator_file.exists() and augmented_reuse_if_exists:
augment_generator = load_from_file(augment_generator_file)
else:
augment_generator = Plugins(categories=plugin_cats).get(
plugin,
**kwargs,
)
try:
if not X.get_fairness_column():
raise ValueError(
"To use the augmentation metrics, `fairness_column` must be set to a string representing the name of a column in the DataLoader."
)
augment_generator.fit(
X.train(),
cond=X.train()[X.get_fairness_column()],
)
except BaseException as e:
log.critical(
f"[{plugin}][take {repeat}] failed to fit augmentation generator: {e}"
)
continue
if synthetic_cache:
save_to_file(augment_generator_file, augment_generator)

if X_augment_cache_file.exists() and augmented_reuse_if_exists:
X_augmented = load_from_file(X_augment_cache_file)
else:
try:
X_augmented = augment_data(
X.train(),
augment_generator,
rule=augmentation_rule,
strict=strict_augmentation,
ad_hoc_augment_vals=ad_hoc_augment_vals,
**generate_kwargs,
)
if len(X_augmented) == 0:
raise RuntimeError("Plugin failed to generate data")
except BaseException as e:
log.critical(
f"[{plugin}][take {repeat}] failed to generate augmentation data: {e}"
)
continue
if synthetic_cache:
save_to_file(X_augment_cache_file, X_augmented)
else:
X_augmented = None
evaluation = Metrics.evaluate(
X_test if X_test is not None else X,
X_test if X_test is not None else X.test(),
X_syn,
X_augmented,
metrics=metrics,
task_type=task_type,
workspace=workspace,
use_cache=use_metric_cache,
)

mean_score = evaluation["mean"].to_dict()
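For orientation, a minimal usage sketch (not part of the diff) of how the new arguments might be exercised. The toy DataFrame, the "ctgan" plugin choice, and the "linear_model_augmentation" metric name are assumptions; the branch above only requires that some requested metric name contains "augmentation", and it assumes `GenericDataLoader` exposes the `fairness_column` argument that backs `get_fairness_column()`:

# illustrative sketch, under the assumptions stated above
import numpy as np
import pandas as pd

from synthcity.benchmark import Benchmarks
from synthcity.plugins.core.dataloader import GenericDataLoader

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "feature": rng.normal(size=300),
        "sex": rng.choice(["F", "M"], size=300, p=[0.25, 0.75]),
        "target": rng.integers(0, 2, size=300),
    }
)
loader = GenericDataLoader(df, target_column="target", fairness_column="sex")

results = Benchmarks.evaluate(
    [("ctgan_test", "ctgan", {})],  # roughly: (test case name, plugin name, plugin kwargs)
    loader,
    metrics={"performance": ["linear_model_augmentation"]},  # hypothetical metric name
    task_type="classification",
    repeats=1,
    augmentation_rule="log",
    strict_augmentation=False,
)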
199 changes: 199 additions & 0 deletions src/synthcity/benchmark/utils.py
@@ -0,0 +1,199 @@
# stdlib
import math
from copy import copy
from typing import Any, Dict, Optional

# third party
import numpy as np
import pandas as pd
from pydantic import validate_arguments
from typing_extensions import Literal

# synthcity absolute
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import DataLoader


def calculate_fair_aug_sample_size(
X_train: pd.DataFrame,
fairness_column: Optional[str], # a categorical column of K levels
rule: Literal[
"equal", "log", "ad-hoc"
], # TODO: Confirm whether there are any more methods to include
ad_hoc_augment_vals: Optional[
Dict[Any, int]
] = None, # Only required for rule == "ad-hoc"
) -> Dict:
"""Calculate how many samples to augment.

Args:
X_train (pd.DataFrame): The real dataset to be augmented.
fairness_column (str): The column name of the column to test the fairness of a downstream model with respect to.
ad_hoc_augment_vals (Dict[ Union[int, str], int ], optional): A dictionary containing the number of each class to augment the real data with. If using rule="ad-hoc" this function returns ad_hoc_augment_vals, otherwise this parameter is ignored. Defaults to {}.

Returns:
Dict: A dictionary containing the number of each class to augment the real data with.
"""

# the majority class is unchanged
if rule == "equal":
# the number of samples will be the same for each value in the fairness column after augmentation
# N_aug(i) = N_aug(j) for all values i and j in the fairness column
fairness_col_counts = X_train[fairness_column].value_counts()
majority_size = fairness_col_counts.max()
augmentation_counts = {
fair_col_val: (majority_size - fairness_col_counts.loc[fair_col_val])
for fair_col_val in fairness_col_counts.index
}
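# illustrative example: value counts {"a": 100, "b": 60, "c": 40} give
# {"a": 0, "b": 40, "c": 60}, topping every class up to the majority size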
elif rule == "log":
# number of samples in aug data will be proportional to the log frequency in the real data.
# Note: taking the log makes the distribution more even.
# N_aug(i) is proportional to log(N_real(i))
fairness_col_counts = X_train[fairness_column].value_counts()
majority_size = fairness_col_counts.max()
log_coefficient = majority_size / math.log(majority_size)

augmentation_counts = {
fair_col_val: (
majority_size - round(math.log(fair_col_count) * log_coefficient)
)
for fair_col_val, fair_col_count in fairness_col_counts.items()
}
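# illustrative example: value counts {"a": 100, "b": 60, "c": 40} give
# log_coefficient = 100 / log(100) ≈ 21.71, so "b" gets 100 - round(log(60) * 21.71) = 11
# and "c" gets 100 - round(log(40) * 21.71) = 20 extra samples, while "a" gets 0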
elif rule == "ad-hoc":
# use user-specified values to augment
if not ad_hoc_augment_vals:
raise ValueError(
"When augmenting with an `ad-hoc` method, ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment."
)
else:
if not set(ad_hoc_augment_vals.keys()).issubset(
set(X_train[fairness_column].values)
):
print(set(X_train[fairness_column].values))
Contributor:
Don't leave prints in the code. Use log if the logs are needed.

Contributor Author:
Deleted.
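For reference, a minimal sketch of the suggested replacement (not part of the diff), assuming the `import synthcity.logger as log` alias already used in `__init__.py`:

# hypothetical replacement for the debug prints
import synthcity.logger as log

log.debug(f"fairness column values: {set(X_train[fairness_column].values)}")
log.debug(f"ad hoc augment keys: {set(ad_hoc_augment_vals.keys())}")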

print(set(ad_hoc_augment_vals.keys()))
raise ValueError(
"ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment."
)
elif set(X_train[fairness_column].values) != set(
ad_hoc_augment_vals.keys()
):
ad_hoc_augment_vals = {
k: v
for k, v in ad_hoc_augment_vals.items()
if k in set(X_train[fairness_column].values)
}

augmentation_counts = ad_hoc_augment_vals

return augmentation_counts


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _generate_synthetic_data(
X_train: DataLoader,
augment_generator: Any,
strict: bool = True,
rule: Literal["equal", "log", "ad-hoc"] = "equal",
ad_hoc_augment_vals: Optional[
Dict[Any, int]
] = None, # Only required for rule == "ad-hoc"
synthetic_constraints: Optional[Constraints] = None,
**generate_kwargs: Any,
) -> pd.DataFrame:

"""Generates synthetic data

Args:
X_train (pd.DataFrame): The dataset used to train the downstream model.
y_train (Union[pd.Series, pd.DataFrame]): The data labels for `X_train`. This is used to train the downstream model.
fairness_column (str): The column name of the column to test the fairness of a downstream model with respect to.
target_column (str): The column name of the label column.
syn_model_name (str): The name of the synthetic model plugin to use to generate the synthetic data.
strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.
rule (Literal["equal", "log", "ad-hoc"): The rule used to achieve the desired proportion records with each value in the fairness column. Defaults to "equal".
ad_hoc_augment_vals (Dict[ Union[int, str], int ], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to {}.
random_state (int, optional): The random state to seed the synthetic data generation. Defaults to 42.

Returns:
pd.DataFrame: The generated synthetic data.
"""
augmentation_counts = calculate_fair_aug_sample_size(
X_train.dataframe(),
X_train.get_fairness_column(),
rule,
ad_hoc_augment_vals=ad_hoc_augment_vals,
)
if not strict:
# set count equal to the total number of records required according to calculate_fair_aug_sample_size
count = sum(augmentation_counts.values())
cond = pd.Series(
np.repeat(
list(augmentation_counts.keys()), list(augmentation_counts.values())
)
)
syn_data = augment_generator.generate(
count=count,
cond=cond,
constraints=synthetic_constraints,
**generate_kwargs,
).dataframe()
else:
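# strict: generate each under-represented class separately under a hard
# Constraints rule, so every synthetic record matches its fairness value exactly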
syn_data_list = []
for fairness_value, count in augmentation_counts.items():
if count > 0:
constraints = Constraints(
rules=[(X_train.get_fairness_column(), "==", fairness_value)]
)
syn_data_list.append(
augment_generator.generate(
count=count, constraints=constraints
).dataframe()
)
syn_data = pd.concat(syn_data_list)
return syn_data


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def augment_data(
X_train: DataLoader,
augment_generator: Any,
strict: bool = False,
rule: Literal["equal", "log", "ad-hoc"] = "equal",
ad_hoc_augment_vals: Optional[
Dict[Any, int]
] = None, # Only required for rule == "ad-hoc"
synthetic_constraints: Optional[Constraints] = None,
**generate_kwargs: Any,
) -> DataLoader:
"""Augment the real data with generated synthetic data

Args:
X (DataLoader): The ground truth DataLoader to augment with synthetic data.
model_name (str): The name of the synthetic model plugin to use to generate the synthetic data.
prefix (str, optional): prefix (str): The prefix for the saved synthetic data generation model filename. Defaults to "fairness.conditional_augmentation".
strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.
rule (Literal["equal", "log", "ad-hoc"): The rule used to achieve the desired proportion records with each value in the fairness column. Defaults to "equal".
ad_hoc_augment_vals (Dict[ Union[int, str], int ], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to {}.

Returns:
Tuple[np.ndarray, np.ndarray]: The augmented dataset and labels.
"""
syn_data = _generate_synthetic_data(
X_train,
augment_generator,
strict=strict,
rule=rule,
ad_hoc_augment_vals=ad_hoc_augment_vals,
synthetic_constraints=synthetic_constraints,
**generate_kwargs,
)

augmented_data_loader = copy(X_train)
augmented_data_loader.data = pd.concat(
[
X_train.data,
syn_data,
]
)

return augmented_data_loader
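To close, a minimal end-to-end sketch (not part of the diff) of the helpers in this file. The toy DataFrame and the "ctgan" plugin are assumptions, the conditional fit mirrors the call in `__init__.py`, and `GenericDataLoader` is assumed to accept the `fairness_column` argument:

# illustrative sketch, under the assumptions stated above
import numpy as np
import pandas as pd

from synthcity.benchmark.utils import augment_data, calculate_fair_aug_sample_size
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "feature": rng.normal(size=200),
        "group": rng.choice(["a", "b", "c"], size=200, p=[0.5, 0.3, 0.2]),
        "target": rng.integers(0, 2, size=200),
    }
)
loader = GenericDataLoader(df, target_column="target", fairness_column="group")

# how many records of each class the "equal" rule would add
print(calculate_fair_aug_sample_size(df, "group", "equal"))

generator = Plugins().get("ctgan")
generator.fit(loader.train(), cond=loader.train()["group"])

augmented = augment_data(loader.train(), generator, rule="equal")
print(len(augmented.dataframe()))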