Augmentation benchmark #150
…nto augmentation-benchmark
@@ -65,7 +65,7 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None:
    ],
)
def test_plugin_fit(tte_strategy: str) -> None:
    test_plugin = plugin(tte_strategy=tte_strategy, device="cpu", **plugins_args)
    test_plugin = plugin(tte_strategy=tte_strategy, device="zz", **plugins_args)
What does zz mean?
Typo. Now fixed
src/synthcity/metrics/eval.py
Outdated
use_cache=use_cache,
),
X_gt.sample(eval_cnt),
X_augmented.sample(eval_cnt),
Is sample(eval_cnt) relevant for the augmented dataset? The augmented dataset will be larger than X_gt every time, right?
Yes, good point. I'll remove the sample call in the next push
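For illustration, a minimal sketch of the point above, using plain pandas DataFrames as hypothetical stand-ins for the data loaders (the real objects in eval.py may behave differently). Only the ground-truth set is subsampled to eval_cnt rows; the augmented set, which contains the ground truth plus extra rows, is passed through whole.

```python
import pandas as pd

# Hypothetical stand-ins: 10 real rows, plus 5 synthetic rows appended
# to form the augmented dataset.
X_gt = pd.DataFrame({"x": range(10)})
X_extra = pd.DataFrame({"x": range(100, 105)})
X_augmented = pd.concat([X_gt, X_extra], ignore_index=True)

# Assumed evaluation budget, capped at the size of the ground truth.
eval_cnt = min(len(X_gt), 8)

# Subsample the ground truth only; keep every augmented row.
gt_eval = X_gt.sample(eval_cnt, random_state=0)
aug_eval = X_augmented

print(len(gt_eval), len(aug_eval))
```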
Great work!
Some changes are still needed; after that, this can be merged.
src/synthcity/benchmark/utils.py
Outdated
if not set(ad_hoc_augment_vals.keys()).issubset(
    set(X_train[fairness_column].values)
):
    print(set(X_train[fairness_column].values))
Don't leave prints in the code. Use log if the logs are needed.
Deleted.
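For reference, a rough sketch of what a logged version of that check could look like. It uses the standard-library logging module as a stand-in for the project's own log helper, and the function name check_ad_hoc_keys, its signature, and the ValueError are assumptions made for illustration, not code from this PR.

```python
import logging
from typing import Any, Dict, Iterable

# Stand-in logger; the project's own logging helper may differ.
log = logging.getLogger(__name__)


def check_ad_hoc_keys(
    column_values: Iterable[Any], ad_hoc_augment_vals: Dict[Any, int]
) -> None:
    """Hypothetical helper: every ad hoc key must occur in the column."""
    observed = set(column_values)
    missing = set(ad_hoc_augment_vals.keys()) - observed
    if missing:
        # Emit the diagnostic through the logger instead of print().
        log.debug("Observed values in fairness column: %s", observed)
        raise ValueError(
            f"ad_hoc_augment_vals keys not found in fairness column: {missing}"
        )
```

With this shape, the diagnostic only appears when debug logging is enabled, so nothing is written to stdout during normal benchmark runs.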
@@ -290,8 +295,6 @@ class GenericDataLoader(DataLoader):
>>> from synthcity.plugins.core.dataloader import GenericDataLoader
>>> X, y = load_diabetes(return_X_y=True, as_frame=True)
>>> X["target"] = y
>>> # Important note: preprocessing data with OneHotEncoder or StandardScaler is not needed or recommended.
>>> # Synthcity handles feature encoding and standardization internally.
Why did you remove these lines?
Accident, re-instated.
tests/metrics/test_api.py
Outdated
@pytest.mark.parametrize(
    "fairness_column, rule, strict, add_hoc_vals",
Do you mean ad_hoc?
yes, fixed
Description
Added an augmentation benchmark pipeline.
closes #136
Affected Dependencies
None
How has this been tested?
Checklist