# run_PD_FL.py — entry script for primal-dual federated learning experiments (114 lines, 4.53 KB).
import time
import torch
from utils.data_utils import load_dataset, make_dataloader, split_dataset, create_imbalance, get_auxiliary_data, make_transforms
from utils.model_utils import make_model
from utils.logger_utils import Logger
from utils.test_utils import make_evaluate_fn, make_monitor_fn
from core.fed_avg import FEDAVG
from core.fed_pd import FEDPD
from core.imbalance_fl import ImbalanceFL
from core.ratio_loss_fl import RatioLossFL
from torch.utils.tensorboard import SummaryWriter
from config import make_parser
import os
import json
# Base federated optimizers, selected via the --learner CLI flag.
FEDERATED_LEARNERS = {
    'fed-avg': FEDAVG,
    'fed-pd': FEDPD
}
# Primal-dual formulations that wrap a base learner, selected via --formulation.
PD_FEDERATED_LEARNERS = {
    'imbalance-fl': ImbalanceFL,
    'ratioloss-fl': RatioLossFL
}
# Class ids that get downsampled for each supported --n_minority value.
_MINORITY_CLASSES = {
    1: (0,),
    3: (0, 2, 4),
    5: (0, 2, 4, 6, 8),
}


def _make_save_dir(args):
    """Return the output/<experiment>/<hyperparameters> path that uniquely
    identifies this run from its parsed command-line arguments."""
    level = args.homo_ratio if args.heterogeneity == "mix" else args.dir_level
    experiment_setup = (
        f"Class_Imbalance_FL_{args.formulation}_{args.heterogeneity}_{level}"
        f"_{args.n_workers}_{args.n_workers_per_round}_{args.dataset}"
        f"_{args.n_minority}_{args.reduce_to_ratio}_{args.model}"
    )
    hyperparameter_setup = (
        f"{args.learner}_{args.global_lr}_{args.local_lr}_{args.lambda_lr}"
        f"_{args.n_p_steps}_{args.tolerance_epsilon}"
        f"_{args.client_step_per_epoch}_{args.local_epoch}"
    )
    if args.learner == "fed-pd":
        # fed-pd has two extra hyperparameters that must be part of the run id.
        hyperparameter_setup += f"_{args.eta}_{args.fed_pd_dual_lr}"
    return os.path.join('output', experiment_setup, hyperparameter_setup)


def main():
    """Configure, load data, and run a primal-dual federated learning experiment.

    Parses all settings from the command line (see config.make_parser), saves
    the resolved config as JSON under output/<experiment>/<hyperparameters>/,
    optionally creates class imbalance in the training set, builds per-client
    dataloaders and TensorBoard loggers, then fits the selected primal-dual
    learner (--formulation) wrapped around a base FL optimizer (--learner).

    Raises:
        ValueError: if --imbalance is set with n_minority >= n_classes or an
            unsupported n_minority value (supported: 1, 3, 5).
    """
    # 1. load the configurations
    args = make_parser().parse_args()
    print(args)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    loss = torch.nn.functional.cross_entropy

    args.save_dir = _make_save_dir(args)
    # exist_ok avoids the TOCTOU race of the exists()-then-makedirs() pattern.
    os.makedirs(args.save_dir, exist_ok=True)
    with open(os.path.join(args.save_dir, 'config.json'), 'w') as f:
        json.dump(vars(args), f)

    # 2. prepare the data set
    dataset_train, dataset_test, n_classes, n_channels, img_size = load_dataset(args.dataset)
    if args.imbalance:
        # Explicit raises instead of assert: validation must survive python -O.
        if args.n_minority >= n_classes:
            raise ValueError(
                f"n_minority ({args.n_minority}) must be smaller than n_classes ({n_classes})"
            )
        try:
            reduce_classes = _MINORITY_CLASSES[args.n_minority]
        except KeyError:
            raise ValueError(
                f"unsupported n_minority {args.n_minority}; "
                f"supported values: {sorted(_MINORITY_CLASSES)}"
            ) from None
        dataset_train = create_imbalance(dataset_train, reduce_classes=reduce_classes,
                                         reduce_to_ratio=args.reduce_to_ratio)

    transforms = make_transforms(args, args.dataset, train=True)  # data augmentation and normalization
    local_datasets = split_dataset(args, dataset_train, transforms)
    local_dataloaders = [make_dataloader(args, "train", local_dataset)
                         for local_dataset in local_datasets]

    transforms_test = make_transforms(args, args.dataset, train=False)
    dataset_test.transform = transforms_test
    test_dataloader = make_dataloader(args, "test", dataset_test)

    model = make_model(args, n_classes, n_channels, device, img_size)
    test_fn_accuracy = make_evaluate_fn(test_dataloader, device, eval_type='accuracy',
                                        n_classes=n_classes, loss_fn=loss)
    test_fn_class_wise_accuracy = make_evaluate_fn(test_dataloader, device,
                                                   eval_type='class_wise_accuracy',
                                                   n_classes=n_classes)
    statistics_monitor_fn = make_monitor_fn()

    # 3. prepare logger
    # Timestamped subdirectory so repeated runs never overwrite each other.
    tb_file = os.path.join(args.save_dir, f'{time.time()}')
    print(f"writing to {tb_file}")
    writer = SummaryWriter(tb_file)
    logger_accuracy = Logger(writer, test_fn_accuracy, test_metric='accuracy')
    logger_class_wise_accuracy = Logger(writer, test_fn_class_wise_accuracy,
                                        test_metric='class_wise_accuracy')
    logger_monitor = Logger(writer, statistics_monitor_fn, test_metric='model_monitor')
    loggers = [logger_accuracy, logger_class_wise_accuracy, logger_monitor]

    # 4. run PD FL
    make_pd_fed_learner = PD_FEDERATED_LEARNERS[args.formulation]
    make_fed_learner = FEDERATED_LEARNERS[args.learner]
    # The base learner gets loggers=None: logging is owned by the PD wrapper.
    fed_learner = make_fed_learner(init_model=model,
                                   client_dataloaders=local_dataloaders,
                                   loss=loss,
                                   loggers=None,
                                   config=args,
                                   device=device
                                   )
    # Auxiliary samples (per the ratio-loss formulation) drawn from the training
    # set; only the ratioloss-fl formulation consumes them.
    n_aux = 5
    auxiliary_data = (get_auxiliary_data(args, transforms_test, dataset_train, n_classes, n_aux)
                      if args.formulation == "ratioloss-fl" else None)
    pd_fed_learner = make_pd_fed_learner(fed_learner, args, loggers, auxiliary_data)
    pd_fed_learner.fit()

    # TODO: persist the final model (a save_model helper is referenced in older
    # revisions but is not imported here).


if __name__ == '__main__':
    main()