-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Augmentation benchmark #150
Merged
Merged
Changes from 16 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
56c97c6
Add new benchmark code
robsdavis e5ffa12
Merge branch 'main' into augmentation-benchmark
robsdavis ce76de1
Merge main into branch
robsdavis 32e049b
Merge branch 'main' into augmentation-benchmark
robsdavis 6c9a0d7
Augmentation benchmark added
robsdavis de042de
Clean up
robsdavis c2736ba
Cleaning up
robsdavis af63f20
Merge branch 'main' of https://github.com/vanderschaarlab/synthcity i…
robsdavis 3459334
Remove unnecessary tutorial file
robsdavis 21657a3
Clean up
robsdavis fdca5a7
clean up
robsdavis b79fe84
Debug test and clean up
robsdavis 06d90f2
Added new tests for augmentation benchmark
robsdavis eec447a
Added new metric api tests for augmentation
robsdavis 2eb7f06
clean up
robsdavis 26aa26c
clean up
robsdavis d34e995
version bumped and clean up
robsdavis 7d6a667
clean up docstrings
robsdavis File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# stdlib | ||
import math | ||
from copy import copy | ||
from typing import Any, Dict, Optional | ||
|
||
# third party | ||
import numpy as np | ||
import pandas as pd | ||
from pydantic import validate_arguments | ||
from typing_extensions import Literal | ||
|
||
# synthcity absolute | ||
from synthcity.plugins.core.constraints import Constraints | ||
from synthcity.plugins.core.dataloader import DataLoader | ||
|
||
|
||
def calculate_fair_aug_sample_size( | ||
X_train: pd.DataFrame, | ||
fairness_column: Optional[str], # a categorical column of K levels | ||
rule: Literal[ | ||
"equal", "log", "ad-hoc" | ||
], # TODO: Confirm are there any more methods to include | ||
ad_hoc_augment_vals: Optional[ | ||
Dict[Any, int] | ||
] = None, # Only required for rule == "ad-hoc" | ||
) -> Dict: | ||
"""Calculate how many samples to augment. | ||
|
||
Args: | ||
X_train (pd.DataFrame): The real dataset to be augmented. | ||
fairness_column (str): The column name of the column to test the fairness of a downstream model with respect to. | ||
ad_hoc_augment_vals (Dict[ Union[int, str], int ], optional): A dictionary containing the number of each class to augment the real data with. If using rule="ad-hoc" this function returns ad_hoc_augment_vals, otherwise this parameter is ignored. Defaults to {}. | ||
|
||
Returns: | ||
Dict: A dictionary containing the number of each class to augment the real data with. | ||
""" | ||
|
||
# the majority class is unchanged | ||
if rule == "equal": | ||
# number of sample will be the same for each value in the fairness column after augmentation | ||
# N_aug(i) = N_ang(j) for all i and j in value in the fairness column | ||
fairness_col_counts = X_train[fairness_column].value_counts() | ||
majority_size = fairness_col_counts.max() | ||
augmentation_counts = { | ||
fair_col_val: (majority_size - fairness_col_counts.loc[fair_col_val]) | ||
for fair_col_val in fairness_col_counts.index | ||
} | ||
elif rule == "log": | ||
# number of samples in aug data will be proportional to the log frequency in the real data. | ||
# Note: taking the log makes the distribution more even. | ||
# N_aug(i) is proportional to log(N_real(i)) | ||
fairness_col_counts = X_train[fairness_column].value_counts() | ||
majority_size = fairness_col_counts.max() | ||
log_coefficient = majority_size / math.log(majority_size) | ||
|
||
augmentation_counts = { | ||
fair_col_val: ( | ||
majority_size - round(math.log(fair_col_count) * log_coefficient) | ||
) | ||
for fair_col_val, fair_col_count in fairness_col_counts.items() | ||
} | ||
elif rule == "ad-hoc": | ||
# use user-specified values to augment | ||
if not ad_hoc_augment_vals: | ||
raise ValueError( | ||
"When augmenting with an `ad-hoc` method, ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment." | ||
) | ||
else: | ||
if not set(ad_hoc_augment_vals.keys()).issubset( | ||
set(X_train[fairness_column].values) | ||
): | ||
print(set(X_train[fairness_column].values)) | ||
print(set(ad_hoc_augment_vals.keys())) | ||
raise ValueError( | ||
"ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment." | ||
) | ||
elif set(X_train[fairness_column].values) != set( | ||
ad_hoc_augment_vals.keys() | ||
): | ||
ad_hoc_augment_vals = { | ||
k: v | ||
for k, v in ad_hoc_augment_vals.items() | ||
if k in set(X_train[fairness_column].values) | ||
} | ||
|
||
augmentation_counts = ad_hoc_augment_vals | ||
|
||
return augmentation_counts | ||
|
||
|
||
@validate_arguments(config=dict(arbitrary_types_allowed=True)) | ||
def _generate_synthetic_data( | ||
X_train: DataLoader, | ||
augment_generator: Any, | ||
strict: bool = True, | ||
rule: Literal["equal", "log", "ad-hoc"] = "equal", | ||
ad_hoc_augment_vals: Optional[ | ||
Dict[Any, int] | ||
] = None, # Only required for rule == "ad-hoc" | ||
synthetic_constraints: Optional[Constraints] = None, | ||
**generate_kwargs: Any, | ||
) -> pd.DataFrame: | ||
|
||
"""Generates synthetic data | ||
|
||
Args: | ||
X_train (pd.DataFrame): The dataset used to train the downstream model. | ||
y_train (Union[pd.Series, pd.DataFrame]): The data labels for `X_train`. This is used to train the downstream model. | ||
fairness_column (str): The column name of the column to test the fairness of a downstream model with respect to. | ||
target_column (str): The column name of the label column. | ||
syn_model_name (str): The name of the synthetic model plugin to use to generate the synthetic data. | ||
strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False. | ||
rule (Literal["equal", "log", "ad-hoc"): The rule used to achieve the desired proportion records with each value in the fairness column. Defaults to "equal". | ||
ad_hoc_augment_vals (Dict[ Union[int, str], int ], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to {}. | ||
random_state (int, optional): The random state to seed the synthetic data generation. Defaults to 42. | ||
|
||
Returns: | ||
pd.DataFrame: The generated synthetic data. | ||
""" | ||
augmentation_counts = calculate_fair_aug_sample_size( | ||
X_train.dataframe(), | ||
X_train.get_fairness_column(), | ||
rule, | ||
ad_hoc_augment_vals=ad_hoc_augment_vals, | ||
) | ||
if not strict: | ||
# set count equal to the total number of records required according to calculate_fair_aug_sample_size | ||
count = sum(augmentation_counts.values()) | ||
cond = pd.Series( | ||
np.repeat( | ||
list(augmentation_counts.keys()), list(augmentation_counts.values()) | ||
) | ||
) | ||
syn_data = augment_generator.generate( | ||
count=count, | ||
cond=cond, | ||
constraints=synthetic_constraints, | ||
**generate_kwargs, | ||
).dataframe() | ||
else: | ||
syn_data_list = [] | ||
for fairness_value, count in augmentation_counts.items(): | ||
if count > 0: | ||
constraints = Constraints( | ||
rules=[(X_train.get_fairness_column(), "==", fairness_value)] | ||
) | ||
syn_data_list.append( | ||
augment_generator.generate( | ||
count=count, constraints=constraints | ||
).dataframe() | ||
) | ||
syn_data = pd.concat(syn_data_list) | ||
return syn_data | ||
|
||
|
||
@validate_arguments(config=dict(arbitrary_types_allowed=True)) | ||
def augment_data( | ||
X_train: DataLoader, | ||
augment_generator: Any, | ||
strict: bool = False, | ||
rule: Literal["equal", "log", "ad-hoc"] = "equal", | ||
ad_hoc_augment_vals: Optional[ | ||
Dict[Any, int] | ||
] = None, # Only required for rule == "ad-hoc" | ||
synthetic_constraints: Optional[Constraints] = None, | ||
**generate_kwargs: Any, | ||
) -> DataLoader: | ||
"""Augment the real data with generated synthetic data | ||
|
||
Args: | ||
X (DataLoader): The ground truth DataLoader to augment with synthetic data. | ||
model_name (str): The name of the synthetic model plugin to use to generate the synthetic data. | ||
prefix (str, optional): prefix (str): The prefix for the saved synthetic data generation model filename. Defaults to "fairness.conditional_augmentation". | ||
strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False. | ||
rule (Literal["equal", "log", "ad-hoc"): The rule used to achieve the desired proportion records with each value in the fairness column. Defaults to "equal". | ||
ad_hoc_augment_vals (Dict[ Union[int, str], int ], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to {}. | ||
|
||
Returns: | ||
Tuple[np.ndarray, np.ndarray]: The augmented dataset and labels. | ||
""" | ||
syn_data = _generate_synthetic_data( | ||
X_train, | ||
augment_generator, | ||
strict=strict, | ||
rule=rule, | ||
ad_hoc_augment_vals=ad_hoc_augment_vals, | ||
synthetic_constraints=synthetic_constraints, | ||
**generate_kwargs, | ||
) | ||
|
||
augmented_data_loader = copy(X_train) | ||
augmented_data_loader.data = pd.concat( | ||
[ | ||
X_train.data, | ||
syn_data, | ||
] | ||
) | ||
|
||
return augmented_data_loader |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't leave prints in the code. use
log
if the logs are needed.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deleted.