-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
134 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Test for the image input handler.""" | ||
import os | ||
import shutil | ||
|
||
import pytest | ||
import yaml | ||
from dotmap import DotMap | ||
|
||
from leakpro import LeakPro | ||
from leakpro.tests.constants import STORAGE_PATH, get_audit_config | ||
from leakpro.tests.input_handler.image_input_handler import ImageInputHandler | ||
from leakpro.tests.input_handler.image_utils import setup_image_test | ||
from leakpro.tests.input_handler.tabular_input_handler import TabularInputHandler | ||
from leakpro.tests.input_handler.tabular_utils import setup_tabular_test | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def manage_storage_directory(): | ||
"""Fixture to create and remove the storage directory.""" | ||
|
||
# Setup: Create the folder at the start of the test session | ||
os.makedirs(STORAGE_PATH, exist_ok=True) | ||
|
||
# Yield control back to the test session | ||
yield | ||
|
||
# Teardown: Remove the folder and its contents at the end of the session | ||
if os.path.exists(STORAGE_PATH): | ||
shutil.rmtree(STORAGE_PATH) | ||
|
||
@pytest.fixture | ||
def image_handler(manage_storage_directory) -> ImageInputHandler: | ||
"""Fixture for the image input handler to be shared between many tests.""" | ||
|
||
config = DotMap() | ||
config.target = setup_image_test() | ||
config.audit = get_audit_config() | ||
config.audit.modality = "image" | ||
#save config to file | ||
config_path = f"{STORAGE_PATH}/image_test_config.yaml" | ||
with open(config_path, "w") as f: | ||
yaml.dump(config.toDict(), f) | ||
|
||
leakpro = LeakPro(ImageInputHandler, config_path) | ||
handler = leakpro.handler | ||
handler.configs = DotMap(handler.configs) | ||
|
||
# Yield control back to the test session | ||
return handler | ||
|
||
@pytest.fixture | ||
def tabular_handler(manage_storage_directory) -> TabularInputHandler: | ||
"""Fixture for the image input handler to be shared between many tests.""" | ||
|
||
config = DotMap() | ||
config.target = setup_tabular_test() | ||
config.audit = get_audit_config() | ||
config.audit.modality = "tabular" | ||
#save config to file | ||
config_path = f"{STORAGE_PATH}/tabular_test_config.yaml" | ||
with open(config_path, "w") as f: | ||
yaml.dump(config.toDict(), f) | ||
|
||
leakpro = LeakPro(TabularInputHandler, config_path) | ||
handler = leakpro.handler | ||
handler.configs = DotMap(handler.configs) | ||
|
||
# Yield control back to the test session | ||
return handler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from dotmap import DotMap | ||
|
||
STORAGE_PATH = "./leakpro/tests/tmp" | ||
|
||
# User input handler for images | ||
|
||
|
||
def get_image_handler_config(): | ||
parameters = DotMap() | ||
parameters.target_folder = "./leakpro/tests/tmp/image" | ||
parameters.epochs = 10 | ||
parameters.batch_size = 64 | ||
parameters.learning_rate = 0.001 | ||
parameters.optimizer = "sgd" | ||
parameters.loss = "crossentropyloss" | ||
parameters.data_points = 130 | ||
parameters.train_data_points = 20 | ||
parameters.test_data_points = 20 | ||
parameters.img_size = (3, 32, 32) | ||
parameters.num_classes = 13 | ||
parameters.images_per_class = parameters.data_points // parameters.num_classes | ||
return parameters | ||
|
||
def get_tabular_handler_config(): | ||
parameters = DotMap() | ||
parameters.target_folder = "./leakpro/tests/tmp/tabular" | ||
parameters.epochs = 10 | ||
parameters.batch_size = 64 | ||
parameters.learning_rate = 0.001 | ||
parameters.optimizer = "sgd" | ||
parameters.loss = "BCEWithLogitsLoss" | ||
parameters.data_points = 500 | ||
parameters.train_data_points = 200 | ||
parameters.test_data_points = 200 | ||
parameters.num_classes = 1 | ||
parameters.n_continuous = 10 | ||
parameters.n_categorical = 5 | ||
return parameters | ||
|
||
def get_audit_config(): | ||
#audit configuration | ||
audit_config = DotMap() | ||
audit_config.output_dir = STORAGE_PATH | ||
audit_config.attack_type = "mia" | ||
# Lira parameters | ||
audit_config.attack_list.lira.training_data_fraction = 0.1 | ||
audit_config.attack_list.lira.num_shadow_models = 3 | ||
audit_config.attack_list.lira.online = False | ||
audit_config.attack_list.lira.fixed_variance = True | ||
return audit_config | ||
|
||
|
||
|
||
# Shadow model configuration for images | ||
def get_shadow_model_config(): | ||
shadow_model_config = DotMap() | ||
shadow_model_config.module_path = "./leakpro/tests/input_handler/image_utils.py" | ||
shadow_model_config.model_class = "ConvNet" | ||
shadow_model_config.batch_size = 32 | ||
shadow_model_config.epochs = 1 | ||
shadow_model_config.optimizer = {"name": "sgd", "lr": 0.001} | ||
shadow_model_config.loss = {"name": "crossentropyloss"} | ||
return shadow_model_config | ||
|
||
|