keywords: federated-learning, asynchronous, synchronous, semi-asynchronous, personalized
The initial intention of this project was to build an asynchronous federated learning framework and run experiments on it for my undergraduate thesis.
However, when I searched GitHub for related open-source projects, I found that asynchronous federated learning is a largely closed-source field with almost no open-source projects available, and that mainstream frameworks only support synchronous FL and offer no compatibility with asynchronous FL. Thus, this project was born.
The master branch is the main branch with the latest code, but some of the commits are dirty commits and not guaranteed to run properly. It is recommended to use tagged versions for better stability.
The checkout branch retains the functionality of adding clients to the system during the training process, which has been removed in the main branch. The checkout branch is not actively maintained and only supports synchronous and asynchronous FL.
Python 3.8 + PyTorch + macOS
It has been validated on Linux.
It supports both single-GPU and multi-GPU training.
You can run `python main.py` (the main file in the `fl` directory) directly. The program will automatically read the `config.json` file in the root directory and store the results, along with the configuration file, in the specified path under `results`.
You can also specify the configuration file explicitly with `python main.py ../../config.json`. Please note that the path of `config.json` is relative to `main.py`.
The `config` folder in the root directory provides configuration files for several algorithms proposed in papers. The following algorithm implementations are currently available:
- FedAvg
- FedAsync
- FedProx
- FedAT
- FedLC
- FedDL
- M-Step AsyncFL
You can now pull and run a Docker image directly; the commands are as follows:
```shell
docker pull desperadoccy/async-fl
docker run -it async-fl config/FedAvg-config.json
```
Similarly, it supports passing a config file path as a parameter. You can also build the Docker image yourself:
```shell
cd docker
docker build -t async-fl .
docker run -it async-fl config/FedAvg-config.json
```
- Asynchronous Federated Learning
- Support model and dataset replacement
- Support scheduling algorithm replacement
- Support aggregation algorithm replacement
- Support loss function replacement
- Support client replacement
- Synchronous federated learning
- Semi-asynchronous federated learning
- Provide test loss information
- Custom label heterogeneity
- Custom data heterogeneity
- Support Dirichlet distribution
- wandb visualization
- Support for leaf-related datasets
- Support for multiple GPUs
- Docker deployment
Project Directory
```
.
├── config Common algorithm configuration files
│ ├── FedAT-config.json
│ ├── FedAsync-config.json
│ ├── FedAvg-config.json
│ ├── FedDL-config.json
│ ├── FedLC-config.json
│ ├── FedProx-config.json
│ ├── MSTEPAsync-config.json
│ ├── config.json
│ └── model_config
│ ├── CIFAR10-config.json
│ ├── ResNet18-config.json
│ └── ResNet50-config.json
├── config.json
├── config_semi.json
├── config_semi_test.json
├── config_sync.json
├── config_sync_test.json
├── config_test.json
├── doc
│ ├── params.docx
│ ├── pic
│ │ ├── fedsemi.png
│ │ ├── framework.png
│ │ └── header.png
│ ├── readme-zh.md
│ └── 参数.docx
├── docker
│ └── Dockerfile
├── license
├── readme.md
├── requirements.txt
└── src
├── checker checker implementation
│ ├── AllChecker.py
│ ├── CheckerCaller.py
│ ├── SyncChecker.py
│ └── __init__.py
├── client client implementation
│ ├── ActiveClient.py
│ ├── Client.py
│ ├── DLClient.py
│ ├── NormalClient.py
│ ├── ProxClient.py
│ ├── SemiClient.py
│ ├── TestClient.py
│ └── __init__.py
├── clientmanager client manager implementation
│ ├── BaseClientManager.py
│ ├── NormalClientManager.py
│ └── __init__.py
├── compressor compressor algorithm class
│ ├── QSGD.py
│ └── __init__.py
├── data
├── dataset
│ ├── CIFAR10.py
│ ├── FashionMNIST.py
│ ├── MNIST.py
│ └── __init__.py
├── exception
│ ├── ClientSumError.py
│ └── __init__.py
├── fl wandb running directory
│ ├── __init__.py
│ ├── main.py
│ └── wandb
├── group group algorithm class
│ ├── AbstractGroup.py
│ ├── DelayGroup.py
│ ├── GroupCaller.py
│ ├── OneGroup.py
│ └── __init__.py
├── groupmanager group manager implementation
│ ├── BaseGroupManager.py
│ ├── NormalGroupManager.py
│ └── __init__.py
├── loss loss algorithm class
│ ├── FedLC.py
│ ├── LossFactory.py
│ └── __init__.py
├── model
│ ├── CNN.py
│ └── __init__.py
├── numgenerator num generator algorithm class
│ ├── AbstractNumGenerator.py
│ ├── NumGeneratorFactory.py
│ ├── StaticNumGenerator.py
│ └── __init__.py
├── queuemanager queuemanager implementation
│ ├── AbstractQueueManager.py
│ ├── BaseQueueManger.py
│ ├── QueueListManager.py
│ ├── SingleQueueManager.py
│ └── __init__.py
├── receiver receiver implementation
│ ├── MultiQueueReceiver.py
│ ├── NoneReceiver.py
│ ├── NormalReceiver.py
│ ├── ReceiverCaller.py
│ └── __init__.py
├── results
├── schedule scheduling algorithm class
│ ├── AbstractSchedule.py
│ ├── FullSchedule.py
│ ├── NoSchedule.py
│ ├── RandomSchedule.py
│ ├── RoundRobin.py
│ ├── ScheduleCaller.py
│ └── __init__.py
├── scheduler scheduler implementation
│ ├── AsyncScheduler.py
│ ├── BaseScheduler.py
│ ├── SemiAsyncScheduler.py
│ ├── SyncScheduler.py
│ └── __init__.py
├── server server implementation
│ ├── AsyncServer.py
│ ├── BaseServer.py
│ ├── SemiAsyncServer.py
│ ├── SyncServer.py
│ └── __init__.py
├── test for test
│ ├── __init__.py
│ ├── test.ipynb
│ └── test.py
├── update update algorithm class
│ ├── AbstractUpdate.py
│ ├── AsyncAvg.py
│ ├── FedAT.py
│ ├── FedAsync.py
│ ├── FedAvg.py
│ ├── FedDL.py
│ ├── StepAsyncAvg.py
│ ├── UpdateCaller.py
│ └── __init__.py
├── updater updater implementation
│ ├── AsyncUpdater.py
│ ├── BaseUpdater.py
│ ├── SemiAsyncUpdater.py
│ ├── SyncUpdater.py
│ └── __init__.py
└── utils
├── ConfigManager.py
├── GlobalVarGetter.py
├── IID.py
├── JsonTool.py
├── ModelTraining.py
├── ModuleFindTool.py
├── Plot.py
├── ProcessTool.py
├── Queue.py
├── Random.py
├── Time.py
├── Tools.py
└── __init__.py
```
The "Time" file under the "utils" package is an implementation of a multi-threaded time acquisition class, and the "Queue" file is an implementation of related functionalities for the "queue" module, as some functionalities of the "queue" module are not yet implemented on macOS.
In synchronous and semi-asynchronous federated learning, the receiver checks whether the updates received during the current global iteration satisfy the configured conditions, for example whether all designated clients have uploaded their updates. If the conditions are met, the updater is triggered to perform global aggregation.
In synchronous and semi-asynchronous federated learning, after a client completes its training, it uploads its weights to the checker class, which decides, according to its own logic, whether the update meets the upload criteria and whether to accept or discard it.
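For orientation, here is a minimal sketch of what a custom checker could look like; the class name, constructor signature, and `check` method are assumptions for illustration rather than the framework's exact API (see `src/checker` for the real implementations):

```python
# A hypothetical custom checker (illustrative only; the framework's actual
# base class and method signature may differ). It accepts an update only
# when the update's staleness stays under a configurable threshold.
class StalenessChecker:
    def __init__(self, config):
        # values taken from the checker "params" section of the config file
        self.max_staleness = config.get("max_staleness", 5)

    def check(self, update_info):
        # update_info is assumed to be a dict carrying a "staleness" entry
        return update_info.get("staleness", 0) <= self.max_staleness
```

As described later in this document, a new checker also needs an invocation method added to the corresponding `Caller` class and a `path`/`params` entry under `checker` in the configuration file.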
Parameter explanation
| parameters | | | | type | explanations |
| --- | --- | --- | --- | --- | --- |
| wandb | enabled | | | bool | whether to enable wandb |
| | project | | | string | project name |
| | name | | | string | the name of this run |
| global | use_file_system | | | bool | whether to enable the file system as the torch multiprocessing sharing strategy |
| | multi_gpu | | | bool | whether to enable multi-GPU (see the detailed explanation below) |
| | experiment | | | string | the name of this run |
| | stale | | | | staleness settings (explained below) |
| | dataset | path | | string | the path of the dataset |
| | | params | | dict | required parameters |
| | iid | | | | iid settings (explained below) |
| | client_num | | | int | the number of clients |
| server | path | | | string | the path of the server |
| | epochs | | | int | global epochs |
| | model | path | | string | the path of the model |
| | | params | | dict | required parameters |
| | scheduler | path | | string | the path of the scheduler |
| | | schedule | path | string | the path of the schedule |
| | | | params | dict | required parameters |
| | | other_params | | * | other parameters |
| | updater | path | | string | the path of the updater |
| | | update | path | string | the path of the update |
| | | | params | dict | required parameters |
| | | loss | | | loss settings (explained below) |
| | | num_generator | | | num generator settings |
| | | group | path | string | the path of the group |
| | | | params | dict | required parameters |
| client_manager | path | | | string | the path of the client manager |
| group_manager | path | | | string | the path of the group manager |
| | group_method | path | | string | the path of the group method |
| | | params | | dict | required parameters |
| queue_manager | path | | | string | the path of the queue manager |
| | receiver | path | | string | the path of the receiver |
| | | params | | dict | required parameters |
| | checker | path | | string | the path of the checker |
| | | params | | dict | required parameters |
| client | path | | | string | the path of the client |
| | epochs | | | int | local epochs |
| | batch_size | | | int | batch size |
| | model | path | | string | the path of the model |
| | | params | | dict | required parameters |
| | loss | | | | loss settings (explained below) |
| | mu | | | float | proximal term's coefficient |
| | optimizer | path | | string | the path of the optimizer |
| | | params | | dict | required parameters |
| | other_params | | | * | other parameters |
To allow clients/servers to call your own algorithms or implementation classes (note: all algorithm implementations must be in class form), the following steps are required:
- Add your own implementation to the corresponding location (dataset, model, schedule, update, client, loss)
- Import the class in the `__init__.py` file of the corresponding package, for example `from model import CNN`
- Declare it in the configuration file; `model_path` corresponds to the path where the new algorithm is located. The `checker`, `group`, `receiver`, `schedule`, and `update` modules need an invocation method added to the corresponding `Caller` class, and the `loss` and `numgenerator` modules need one added to the corresponding factory class.
In addition, parameters that the algorithm needs to use can be declared in the `params` configuration item.
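As a concrete illustration of these steps, the sketch below adds a hypothetical model; the file name `MyCNN.py`, the class name, and the 28x28 single-channel input it assumes are all illustrative:

```python
# src/model/MyCNN.py -- a hypothetical custom model (names are illustrative);
# it assumes 1x28x28 inputs such as MNIST/FashionMNIST.
import torch.nn as nn

class MyCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 28x28 input -> 14x14 feature map with 32 channels after pooling
        self.classifier = nn.Linear(32 * 14 * 14, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))
```

It would then be imported in `model/__init__.py` and referenced through the `path`/`params` entries of the `model` section in the configuration file.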
The `model`, `optim`, and `loss` modules now also support built-in implementation classes from `torch`, for example:
"model": {
"path": "torchvision.models.resnet18",
"params": {
"pretrained": true,
"num_classes": 10
}
}
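Such `path` entries generally work by importing the dotted module path and instantiating the referenced class or factory with `params`. The following sketch only illustrates that idea; it is not the project's actual `ModuleFindTool` code:

```python
# Minimal sketch of dotted-path loading implied by "path" + "params" entries
# (illustrative only, not the project's actual ModuleFindTool implementation).
import importlib

def load_class(dotted_path):
    module_name, attr_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)

# Example: build a torchvision ResNet-18 from a config-style dict.
config = {"path": "torchvision.models.resnet18", "params": {"num_classes": 10}}
model_cls = load_class(config["path"])
model = model_cls(**config["params"])
```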
The loss function is now created by the `LossFactory` class. You can choose to use built-in algorithms from Torch or implement your own.
The loss configuration supports three settings. The first option uses the string format commonly seen in the configuration files:
"loss": "torch.nn.functional.cross_entropy"
In this case, the program will directly generate a loss function using the `functional` approach.
The second option is to generate an object-based loss:
"loss": {
"path": "loss.myloss.MyLoss",
"params": {}
}
Here, you specify the path to your custom loss class and provide any necessary parameters in the params field.
The third option is to generate a loss based on the type:
"loss": {
"type": "func",
"path": "loss.myloss.MyLoss",
"params": {}
}
With this option, you also provide the type field as "func", and the rest of the process is similar to the object-based approach.
`stale` has three settings, one of which is mentioned in the configuration file above.
"stale": {
"step": 5,
"shuffle": true,
"list": [10, 10, 10, 5, 5, 5, 5]
}
The program will generate a list of random integers based on the provided `step` and `list`. For example, with the configuration above, the program generates 10 zeros, 10 random integers in [0, 5), 10 in [5, 10), and so on for the remaining entries, and shuffles them if `shuffle` is set to true. The resulting delays are then assigned to the clients, and each client sleeps for the corresponding number of seconds after every round of training. When the configuration is stored alongside the experimental results, this setting is automatically converted to the third form (an explicit list).
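A minimal sketch of the generation logic described above (illustrative only, not the framework's exact implementation):

```python
import random

# Illustrative sketch of the stale-delay generation described above:
# the first group gets zero delay, each following group draws from the
# next interval of width `step`.
def generate_delays(step, counts, shuffle=True):
    delays = []
    for i, count in enumerate(counts):
        if i == 0:
            delays += [0] * count
        else:
            lower, upper = (i - 1) * step, i * step
            delays += [random.randint(lower, upper - 1) for _ in range(count)]
    if shuffle:
        random.shuffle(delays)
    return delays

print(generate_delays(5, [10, 10, 10, 5, 5, 5, 5]))
```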
The second option is to set it to false, in which case the program will set the delay for each client to 0.
"stale": false
The third option is a list of random integers, and the program will directly assign the delay settings from the list to the clients.
"stale": [1, 2, 3, 1, 4]
When `iid` is set to true (in fact, this is also the default behaviour when it is set to false), the data will be distributed to each client in an independent and identically distributed (iid) way.
"iid": true
When `customize` in `iid` is set to false or not set, the data will be distributed to each client according to a Dirichlet distribution; `beta` is the parameter of the Dirichlet distribution.
"iid": {
"customize": false,
"beta": 0.5
}
or
"iid": {
"beta": 0.5
}
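For reference, a common way to realise such a Dirichlet split is to draw, for each label, a proportion vector over the clients from Dir(beta) and slice that label's samples accordingly. The sketch below shows this standard technique and is not the project's exact `IID.py` code:

```python
import numpy as np

# Standard Dirichlet-based label partitioning (illustrative only):
# smaller beta -> more skewed label distribution per client.
def dirichlet_partition(labels, client_num, beta, seed=0):
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(client_num)]
    for label in np.unique(labels):
        idx = np.where(labels == label)[0]
        rng.shuffle(idx)
        # proportions of this label assigned to each client
        proportions = rng.dirichlet([beta] * client_num)
        split_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, split_points)):
            client_indices[client_id].extend(part.tolist())
    return client_indices
```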
Customized non-iid settings are divided into two parts: label non-iid and data-quantity non-iid. Currently, only random generation is provided for the data quantity; personalized settings will be introduced in future versions.
When enabling the customized setting, you need to set `customize` to true and configure `label` and `data` separately.
"iid": {
"customize": true
}
Label setting is similar to staleness settings and supports three modes. The first one is mentioned in the configuration file.
"label": {
"step": 1,
"list": [10, 10, 30]
}
The above configuration will generate 10 clients with 1 label each, 10 clients with 2 labels each, and 30 clients with 3 labels each.
If `step` is set to 2, the program will instead generate 10 clients with 1 label, 10 clients with 3 labels, and 30 clients with 5 labels.
The second option is a two-dimensional array (one label list per client), which the program assigns directly to the clients.
"label": {
"0": [1, 2, 3, 8],
"1": [2, 4],
"2": [4, 7],
"3": [0, 2, 3, 6, 9],
"4": [5]
}
The third option is a one-dimensional array, which represents the number of labels each client has, and the length of the array should be the same as the number of clients.
"label": {
"list": [4, 5, 10, 1, 2, 3, 4]
}
The above configuration sets the number of labels for each client: client 0 has 4 labels, client 1 has 5 labels, and so on.
Currently, there are two randomization methods for generating label non-iid data. The first is pure randomization, which may leave some label unassigned to any client and thus reduce accuracy (although the probability is extremely low). The second uses a shuffle algorithm to guarantee that every label is selected by at least one client, but it cannot produce data with uneven label distributions. The shuffle algorithm is controlled by the `shuffle` parameter, as shown below:
"label": {
"shuffle": true,
"list": [4, 5, 10, 1, 2, 3, 4]
}
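The difference between the two randomization methods can be sketched as follows (illustrative only; the project's actual implementation lives in `src/utils` and may differ):

```python
import random

# Pure randomization: each client samples its labels independently, so some
# label may end up chosen by no client at all.
def pure_random_labels(counts, num_labels=10):
    return [random.sample(range(num_labels), k) for k in counts]

# Shuffle-based selection: labels are drawn from a pool that cycles through
# shuffled permutations, so every label is handed out at least once as long
# as the total number of slots is at least num_labels.
def shuffled_labels(counts, num_labels=10):
    pool = []
    assignments = []
    for k in counts:
        picked = set()
        while len(picked) < k:
            if not pool:
                pool = list(range(num_labels))
                random.shuffle(pool)
            picked.add(pool.pop())
        assignments.append(sorted(picked))
    return assignments
```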
The data setting is relatively simple; currently there are two methods, the first of which is to leave it empty.
"data": {}
That is, no non-iid setting is performed on the data quantity.
The second method is mentioned in the configuration file.
"data": {
"max": 500,
"min": 400
}
That is, the data quantity for each client will be randomly drawn between 400 and 500 and, by default, distributed evenly among that client's labels.
The data-quantity distribution is still relatively simple at this point and will be gradually improved in the future.
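A minimal sketch of what such a `data` setting implies, assuming each client's quantity is drawn uniformly from [min, max] and split evenly across its labels (illustrative only, not the project's exact code):

```python
import random

# Illustrative: draw a data quantity per client and split it evenly across
# that client's labels.
def sample_quantities(client_labels, min_size=400, max_size=500):
    plans = []
    for labels in client_labels:
        total = random.randint(min_size, max_size)
        per_label = total // len(labels)
        plans.append({label: per_label for label in labels})
    return plans

print(sample_quantities([[1, 2, 3, 8], [2, 4]]))
```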
Currently, client replacement needs to inherit from `AsyncClient` or `SyncClient`, and the new parameters are passed into the class through the `client` configuration item.
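As an illustration, a custom client might look like the sketch below; the import path, constructor signature, and the way extra parameters arrive are assumptions for illustration rather than the framework's exact API:

```python
# A hypothetical custom client (illustrative; the real base-class constructor
# and training hooks in src/client may differ).
from client import AsyncClient  # assumed import, matching the text above

class MyClient(AsyncClient):
    def __init__(self, *args, temperature=1.0, **kwargs):
        # extra entries from the "client" config item are assumed to arrive
        # as keyword arguments; "temperature" is a made-up example parameter
        super().__init__(*args, **kwargs)
        self.temperature = temperature
```

The new class is then imported in `client/__init__.py` and referenced via `path` in the `client` configuration item.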
The multi-GPU feature of this project is not about multi-GPU parallel computing. Each client is still trained on a single GPU, but macroscopically the clients run on multiple GPUs; that is, the clients' training tasks are evenly distributed across the GPUs visible to the program. The GPU bound to each client is fixed at initialization rather than re-assigned at each round of training, so a serious imbalance in GPU load is still possible.
This feature is controlled by the `multi_gpu` switch in the global settings.
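A minimal sketch of this kind of static client-to-GPU binding (illustrative only; the framework's actual assignment code may differ):

```python
import torch

# Illustrative static GPU binding: each client is bound to one device at
# initialization, round-robin over the GPUs visible to the program.
def assign_devices(client_num):
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return [torch.device("cpu")] * client_num
    return [torch.device(f"cuda:{i % gpu_count}") for i in range(client_num)]

devices = assign_devices(client_num=8)
print(devices)  # e.g. [cuda:0, cuda:1, cuda:0, ...] on a 2-GPU machine
```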
Currently, there is a core issue in the framework: communication between clients and the server is implemented with `multiprocessing` queues, and when a CUDA tensor is put on a queue and retrieved by another thread, it can cause a memory leak and may crash the program.
This bug is caused by PyTorch and the multiprocessing queue. The current workaround is to upload only non-CUDA tensors to the queue and convert them back to CUDA tensors during aggregation. Therefore, when adding aggregation algorithms, code like the following is needed:
```python
import torch

# Clone the received (CPU) weights and move them back to the GPU, if one is
# available, before aggregation; only non-CUDA tensors travel through the queue.
updated_parameters = {}
for key, var in client_weights.items():
    updated_parameters[key] = var.clone()
    if torch.cuda.is_available():
        updated_parameters[key] = updated_parameters[key].cuda()
```
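On the sending side, the same workaround amounts to moving the weights to the CPU before putting them on the queue. A minimal sketch, where `model` and `queue` stand for the client's local model and the shared upload queue:

```python
# Illustrative client-side counterpart: move the weights to the CPU before
# putting them on the multiprocessing queue.
cpu_weights = {key: var.detach().cpu() for key, var in model.state_dict().items()}
queue.put(cpu_weights)
```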
Desperadoccy, Jzj007
QQ: 527707607
email: [email protected]
Suggestions for the project are welcome~
If you would like to contribute to this project, please contact us.