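"""Command-line configuration for the federated learning experiments.

make_parser() builds the argparse.ArgumentParser holding all hyperparameters,
grouped by the learner they apply to (fed-avg, fed-pd, scaffold, ffgb-d,
fedavg-d) and by the class-imbalance formulation.
"""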
import argparse


def make_parser():
    parser = argparse.ArgumentParser()
    # general configurations
    parser.add_argument('--n_pd_rounds', type=int, default=5000, help='total dual rounds for PDFL')
    parser.add_argument('--n_global_rounds', type=int, default=5000, help='total communication rounds for FL')
    parser.add_argument('--test_batch_size', type=int, default=200)
    parser.add_argument('--use_ray', action='store_true')
    parser.add_argument('--load_model', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--test_metric', type=str, choices=['accuracy', 'class_wise_accuracy'],
                        default='class_wise_accuracy', help='what to report in tensorboard')
    parser.add_argument('--eval_freq', type=int, default=1, help='how often the test loss should be checked')
    parser.add_argument('--weighted', action='store_true', help='allow clients to have different weights initially')
    parser.add_argument('--loss_fn', type=str, choices=['focal-loss', 'cross-entropy-loss'],
                        default='cross-entropy-loss', help='loss function')
    # tricks for NN training
    parser.add_argument('--no_data_augmentation', action='store_true', help='disable the data augmentation')
    parser.add_argument('--use_gradient_clip', action='store_true')
    parser.add_argument('--gradient_clip_constant', type=float, default=5.)
    parser.add_argument('--weight_decay', type=float, default=1e-3)
    # Experiment setup
    parser.add_argument('--heterogeneity', type=str, choices=['mix', 'dir'], default='mix',
                        help='type of heterogeneity: mix or dir (Dirichlet)')
    parser.add_argument('--homo_ratio', type=float, default=1.)
    parser.add_argument('--dir_level', type=float, default=.3, help='hyperparameter of the Dirichlet distribution')
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'mnist', 'fashion-mnist', 'emnist-letter', 'emnist-digit', 'shakespeare'],
                        default='cifar10', help='dataset (and the corresponding task)')
    parser.add_argument('--dense_hid_dims', type=str, default='384-192')
    parser.add_argument('--conv_hid_dims', type=str, default='64-64')
    parser.add_argument('--model', type=str, choices=['mlp', 'convnet', 'resnet'], default='convnet')
    parser.add_argument('--learner', type=str, choices=['fed-avg', 'fed-pd', 'scaffold', 'ffgb-d', 'fedavg-d'], default='fed-pd')
    parser.add_argument('--formulation', type=str, choices=['imbalance-fl', 'ratioloss-fl', 'GHMC_loss'],
                        default='imbalance-fl', help='formulation for handling the class-imbalance problem')
    parser.add_argument('--n_workers', type=int, default=50)
    parser.add_argument('--n_workers_per_round', type=int, default=5)
    parser.add_argument('--l2_reg', type=float, default=-1.)
    parser.add_argument('--imbalance', action='store_true', help='create imbalance among classes')
    parser.add_argument('--n_minority', type=int, default=1, help='number of minority classes')
    parser.add_argument('--reduce_to_ratio', type=float, default=1.)
    # General hyperparameters
    parser.add_argument('--local_lr', type=float, default=0.1)
    parser.add_argument('--global_lr', type=float, default=1.)
    parser.add_argument('--local_epoch', type=int, default=5)
    parser.add_argument('--client_step_per_epoch', type=int, default=5)
    # Hyperparameters for the formulation "imbalance-fl"
    parser.add_argument('--lambda_lr', type=float, default=1)
    parser.add_argument('--tolerance_epsilon', type=float, default=1.)
    parser.add_argument('--n_p_steps', type=int, default=5, help="primal steps per dual step in PDFL")
    # Hyperparameters for fed-pd
    parser.add_argument('--eta', type=float, default=10)
    parser.add_argument('--fed_pd_dual_lr', type=float, default=1)
    # General hyperparameters for Functional Federated Learning
    parser.add_argument('--local_dataloader_batch_size', type=int, default=50)
    parser.add_argument('--distill_dataloader_batch_size', type=int, default=50)
    parser.add_argument('--dataset_distill', type=str,
                        choices=['cifar10', 'cifar100', 'mnist', 'fashion-mnist', 'emnist-letter', 'emnist-digit',
                                 'shakespeare'],
                        default='cifar10')
    # Hyperparameters for FFGB-D
    parser.add_argument('--functional_lr', type=float, default=10.)
    parser.add_argument('--f_l2_reg', type=float, default=5e-3)
    parser.add_argument('--local_steps', type=int, default=1)
    parser.add_argument('--weak_learner_epoch', type=int, default=120)
    parser.add_argument('--weak_learner_lr', type=float, default=1e-3)
    parser.add_argument('--weak_learner_weight_decay', type=float, default=0)
    # Hyperparameters for FEDAVG-D
    parser.add_argument('--fedavg_d_local_lr', type=float, default=.01)
    parser.add_argument('--fedavg_d_local_epoch', type=int, default=50)
    parser.add_argument('--fedavg_d_weight_decay', type=float, default=1e-3)
    # distill oracle
    parser.add_argument('--distill_oracle', type=str, choices=["kl", "l2"], default="l2")
    parser.add_argument('--distill_oracle_epoch', type=int, default=10)
    parser.add_argument('--distill_oracle_lr', type=float, default=1e-3)
    parser.add_argument('--distill_oracle_weight_decay', type=float, default=.001)
    return parser
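

# --- Usage sketch (illustrative only; not part of the original module) ---
# A minimal example of how this parser might be consumed by a training script.
# The helper name `parse_hidden_dims` and the example flags below are
# assumptions added for illustration, not definitions from the repository.
def parse_hidden_dims(spec):
    """Turn a '-'-separated size string such as '384-192' into [384, 192]."""
    return [int(d) for d in spec.split('-')]


if __name__ == '__main__':
    # Parse a hypothetical command line and show a few derived settings.
    args = make_parser().parse_args(['--dataset', 'cifar10',
                                     '--learner', 'fed-avg',
                                     '--n_workers_per_round', '10'])
    print('dense hidden dims:', parse_hidden_dims(args.dense_hid_dims))
    print('conv hidden dims :', parse_hidden_dims(args.conv_hid_dims))
    print('learner          :', args.learner)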