Skip to content

Commit

Permalink
20241209 update
Browse files Browse the repository at this point in the history
  • Loading branch information
john committed Dec 9, 2024
1 parent a70e1ab commit bbefd23
Show file tree
Hide file tree
Showing 23 changed files with 3,157 additions and 1 deletion.
14 changes: 13 additions & 1 deletion README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@ The following figure shows an overview of our SlaugFL. The SlaugFL consists of t

We divide the implementation code into two parts: The Preparation Phase and The Collaborative Training Phase.

## Setup
Install the packages listed in the requirements file.

## The Preparation Phase


## The Collaborative Training Phase

1. Change the arguments in config.py
2. Run the following script:
```
python main.py --function=run_job
```

## Citation

Expand All @@ -27,4 +35,8 @@ If you find this code is useful to your research, please consider to cite our pa
year={2024},
publisher={IEEE}
}
```
```

## Reference code
1. FedLab: https://github.com/SMILELab-FL/FedLab
2. FedGen: https://github.com/zhuangdizhu/FedGen
53 changes: 53 additions & 0 deletions The_Collaborative_Training_Phase/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#coding:utf8
import warnings
import argparse
import sys

def parse_arguments(argv):
    """Build the experiment argument parser and parse ``argv``.

    Args:
        argv: list of command-line tokens to parse (typically ``sys.argv[1:]``).
            ``None`` makes argparse fall back to the process command line.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """

    def _float_tuple(text):
        # ``type=tuple`` would split the raw string into single characters
        # ("0.5" -> ('0', '.', '5')).  Accept "0.49,0.48,0.44", "0.49 0.48 0.44"
        # or "(0.49, 0.48, 0.44)" instead.  A ValueError raised by ``float`` is
        # turned into a normal usage error by argparse.
        cleaned = text.strip().strip("()[]").replace(",", " ")
        return tuple(float(part) for part in cleaned.split())

    parser = argparse.ArgumentParser()
    parser.add_argument('--function', type=str, help='Name of the function you called.', default="")
    parser.add_argument('--lr_decay', help='learning rate decay rate', type=float, default=0.998)
    parser.add_argument("--dataset", type=str, default="cifar100")
    parser.add_argument("--dataset_mean", type=_float_tuple, default=(0.4914, 0.4822, 0.4465))  # cifar10
    parser.add_argument("--dataset_std", type=_float_tuple, default=(0.2023, 0.1994, 0.2010))  # cifar10
    parser.add_argument("--model_name", type=str, default="ResNet18")
    parser.add_argument("--model_outputdim", type=int, default=10)
    parser.add_argument("--train", type=int, default=1, choices=[0, 1])
    parser.add_argument("--algorithm", type=str, default="FedSlaug")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=0.01, help="Local learning rate")
    parser.add_argument("--dataset_path", type=str, default="", help="the dataset path")
    parser.add_argument("--seed", type=int, default=2023, help="")
    parser.add_argument("--partition_method", type=str, default="dirichlet", help="")
    # The default is 0.1, so the option must be parsed as a float; ``type=int``
    # rejected every fractional value supplied on the command line.
    parser.add_argument("--dirichlet_alpha", type=float, default=0.1, help="")
    parser.add_argument("--total_clients", type=int, default=100, help="")
    parser.add_argument('--gpu_num', type=str, default='0', help='choose which gpu to use')
    parser.add_argument("--num_glob_rounds", type=int, default=300)
    parser.add_argument("--local_epochs", type=int, default=10)
    parser.add_argument("--num_clients_per_round", type=int, default=10, help="Number of Users per round")
    parser.add_argument("--repeat_times", type=int, default=3, help="total repeat times")
    parser.add_argument("--save_path", type=str, default="./results/FedSlaug", help="directory path to save results")
    parser.add_argument("--eval_every", type=int, default=1, help="the number of rounds to evaluate the model performance. 1 is recommend here.")
    parser.add_argument("--save_every", type=int, default=1, help="the number of rounds to save the model.")
    parser.add_argument('--weight_decay', type=float, help='weight decay', default=5e-04)
    parser.add_argument('--genmodel_weight_decay', type=float, help='weight decay', default=1e-04)
    parser.add_argument('--GAN_type', type=int, help='0:case 2 weak gan, 1:case 1 strong gan', default=1)
    parser.add_argument('--GAN_dir', type=str, help='model dir of trained gan', default='')
    parser.add_argument('--GAN_name', type=str, help='model name of trained gan', default='')

    # Parse only the tokens handed in by the caller.  The previous extra
    # ``parser.parse_args()`` call re-parsed the process command line, ignored
    # ``argv`` and discarded its result.
    return parser.parse_args(argv)


# Parsed once, at import time, from the process command line.
# NOTE(review): presumably other modules import `flargs` from here — confirm against callers.
flargs = parse_arguments(sys.argv[1:])

def print_parameters(_print,args):
    """Report every attribute of ``args`` through ``_print``, one per call.

    Args:
        _print: callable taking a single string (e.g. ``print`` or a logger method).
        args: object whose ``__dict__`` holds the parameters (e.g. an argparse Namespace).
    """
    attributes = args.__dict__
    for name in attributes:
        # Each line has the form "<name> <value>".
        _print(" ".join((name, str(attributes[name]))))

if __name__ == "__main__":
    # Manual smoke check: parse the real command line and show the result's type.
    args = parse_arguments(sys.argv[1:])
    result_type = type(args)
    print(result_type)
Empty file.
133 changes: 133 additions & 0 deletions The_Collaborative_Training_Phase/data/cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch
import torchvision
import os
import sys
sys.path.append("..")
from tqdm import trange
import torchvision.transforms as T
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import torch.utils.data as data
from data.data_division_utils.partition import CIFAR10Partitioner
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Module-level constants.
# NOTE(review): neither is referenced by the definitions below
# (`cifar10_hetero_dir_part` takes its own `num_clients` parameter) — confirm before removing.
num_clients = 10
hist_color = '#4169E1'

class client_cifar10(data.Dataset):
    """Dataset wrapping one client's pre-partitioned CIFAR-10 split stored on disk.

    The split is read from ``<data_root_dir>/<mode>/data_<mode>_<cid>.pt`` and is
    expected to be a sequence of ``(sample, label)`` pairs.
    """

    def __init__(self,indices, mode='train', data_root_dir=None, transform=None, target_transform=None):
        """Load the split belonging to client ``indices``.

        Args:
            indices: client id used to build the file name.
            mode: sub-directory and file-name tag (``'train'`` by default).
            data_root_dir: directory that contains the ``<mode>`` sub-directory.
            transform: optional callable applied to each sample in ``__getitem__``.
            target_transform: optional callable applied to each label in ``__getitem__``.
        """
        self.cid = indices
        # SECURITY NOTE: torch.load unpickles the file; only load files you trust.
        # NOTE(review): recent PyTorch releases default ``weights_only`` to True, which
        # may refuse pickled non-tensor samples — confirm against the installed version.
        self.client_dataset = torch.load(os.path.join(data_root_dir, mode, "data_{}_{}.pt".format(mode,self.cid)))
        self.transform = transform
        self.target_transform = target_transform
        self.src_data, self.src_label = zip(*self.client_dataset)
        self.data = list(self.src_data)
        self.label = list(self.src_label)
        self.src_datalen = len(self.src_label)
        self.src_class_list = list(set(self.src_label))

    def get_each_class_nm(self,label):
        """Return ``(class_list, counts)`` where ``counts[k]`` is the number of
        entries of ``label`` equal to ``class_list[k]``."""
        y_np = np.array(label)
        class_list = list(set(label))
        datanm_byclass = [y_np[y_np == cls].size for cls in class_list]
        return class_list,datanm_byclass

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index`` after optional transforms."""
        img, label = self.data[index], self.label[index]

        if self.transform is not None:
            img = self.transform(img)
        else:
            # torch.from_numpy only accepts ndarrays.  The partitioning code in this
            # file stores the raw dataset samples, which are not ndarrays, so convert
            # first instead of raising TypeError.  ndarrays are passed through as-is.
            if not isinstance(img, np.ndarray):
                img = np.array(img)
            img = torch.from_numpy(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label

    def __len__(self):
        """Number of samples held by this client."""
        return len(self.data)

def rearrange_data_by_class(data, targets, n_class):
    """Group ``data`` by label.

    Returns a list of length ``n_class`` whose ``c``-th entry is ``data`` indexed
    by the boolean mask ``targets == c`` (array-style inputs are expected).
    """
    return [data[targets == cls] for cls in range(n_class)]

def cifar10_hetero_dir_part(_print = print,seed=2023,dataset_path = "",save_path = "",num_clients=10,balance=None,partition="dirichlet",dir_alpha=0.3):
    """Partition CIFAR-10 across clients and save every client's train/test split.

    The training set is split with ``CIFAR10Partitioner``.  For each client, and for
    each class present in that client's training data, the next
    ``int(len(class_test_samples) / num_clients)`` not-yet-assigned test samples of
    that class are added to the client's test split.

    Args:
        _print: callable used to report each client's classes.
        seed, balance, partition, dir_alpha: forwarded to ``CIFAR10Partitioner``.
        dataset_path: root directory handed to ``torchvision.datasets.CIFAR10``.
        save_path: output directory; files are written to
            ``<save_path>/train/data_train_<cid>.pt`` and
            ``<save_path>/test/data_test_<cid>.pt``.
        num_clients: number of clients to partition the data over.

    Returns:
        The ``CIFAR10Partitioner`` instance that produced the split.
    """
    # Train/test datasets (no transform is passed, so samples are stored as loaded).
    trainset = torchvision.datasets.CIFAR10(root=dataset_path,
                                            train=True, download=True)
    testset = torchvision.datasets.CIFAR10(root=dataset_path,
                                           train=False, download=False)

    partitioner = CIFAR10Partitioner(trainset.targets,
                                     num_clients,
                                     balance=balance,
                                     partition=partition,
                                     dir_alpha=dir_alpha,
                                     seed=seed)

    data_indices = partitioner.client_dict
    class_nm = 10
    client_traindata_path = os.path.join(save_path, "train")
    client_testdata_path = os.path.join(save_path, "test")
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(client_traindata_path, exist_ok=True)
    os.makedirs(client_testdata_path, exist_ok=True)

    trainsamples, trainlabels = [], []
    for x, y in trainset:
        trainsamples.append(x)
        trainlabels.append(y)

    # Object arrays so the test samples can be selected with boolean masks.
    testsamples = np.empty((len(testset),), dtype=object)
    testlabels = np.empty((len(testset),), dtype=object)
    for i, z in enumerate(testset):
        testsamples[i] = z[0]
        testlabels[i] = z[1]
    rearrange_testsamples = rearrange_data_by_class(testsamples, testlabels, class_nm)
    # Per-class cursor: how many test samples of each class are already handed out.
    testdata_nmidx = {l: 0 for l in range(class_nm)}

    # `cid` instead of `id` so the builtin is not shadowed.
    for cid, indices in data_indices.items():
        traindata = [trainsamples[idx] for idx in indices]
        trainlabel = [trainlabels[idx] for idx in indices]

        user_sampled_labels = list(set(trainlabel))
        _print("client {}'s classes:{}".format(cid, user_sampled_labels))
        testdata, testlabel = [], []
        for l in user_sampled_labels:
            class_samples = rearrange_testsamples[l]
            start = testdata_nmidx[l]
            num_samples = int(len(class_samples) / num_clients)
            assert num_samples + start <= len(class_samples)
            testdata += class_samples[start:start + num_samples].tolist()
            testlabel += (l * np.ones(num_samples, dtype=int)).tolist()
            assert len(testdata) == len(testlabel), f"{len(testdata)} == {len(testlabel)}"
            testdata_nmidx[l] = start + num_samples

        train_dataset = list(zip(traindata, trainlabel))
        test_dataset = list(zip(testdata, testlabel))
        torch.save(
            train_dataset,
            os.path.join(client_traindata_path, "data_train_{}.pt".format(cid)))
        torch.save(
            test_dataset,
            os.path.join(client_testdata_path, "data_test_{}.pt".format(cid)))

    return partitioner




131 changes: 131 additions & 0 deletions The_Collaborative_Training_Phase/data/cifar100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import torch
import torchvision
import os
from tqdm import trange
import torchvision.transforms as T
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import sys
sys.path.append("..")
import torch.utils.data as data
from data.data_division_utils.partition import CIFAR100Partitioner
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Module-level constants.
# NOTE(review): neither is referenced by the definitions below
# (`cifar100_hetero_dir_part` takes its own `num_clients` parameter) — confirm before removing.
num_clients = 10
hist_color = '#4169E1'

class client_cifar100(data.Dataset):
    """Dataset wrapping one client's pre-partitioned CIFAR-100 split stored on disk.

    The split is read from ``<data_root_dir>/<mode>/data_<mode>_<cid>.pt`` and is
    expected to be a sequence of ``(sample, label)`` pairs.
    """

    def __init__(self,indices, mode='train', data_root_dir=None, transform=None, target_transform=None):
        """Load the split belonging to client ``indices``.

        Args:
            indices: client id used to build the file name.
            mode: sub-directory and file-name tag (``'train'`` by default).
            data_root_dir: directory that contains the ``<mode>`` sub-directory.
            transform: optional callable applied to each sample in ``__getitem__``.
            target_transform: optional callable applied to each label in ``__getitem__``.
        """
        self.cid = indices
        # SECURITY NOTE: torch.load unpickles the file; only load files you trust.
        # NOTE(review): recent PyTorch releases default ``weights_only`` to True, which
        # may refuse pickled non-tensor samples — confirm against the installed version.
        self.client_dataset = torch.load(os.path.join(data_root_dir, mode, "data_{}_{}.pt".format(mode,self.cid)))
        self.transform = transform
        self.target_transform = target_transform
        self.src_data, self.src_label = zip(*self.client_dataset)
        self.data = list(self.src_data)
        self.label = list(self.src_label)
        self.src_datalen = len(self.src_label)
        self.src_class_list = list(set(self.src_label))

    def get_each_class_nm(self,label):
        """Return ``(class_list, counts)`` where ``counts[k]`` is the number of
        entries of ``label`` equal to ``class_list[k]``."""
        y_np = np.array(label)
        class_list = list(set(label))
        datanm_byclass = [y_np[y_np == cls].size for cls in class_list]
        return class_list,datanm_byclass

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index`` after optional transforms."""
        img, label = self.data[index], self.label[index]

        if self.transform is not None:
            img = self.transform(img)
        else:
            # torch.from_numpy only accepts ndarrays.  The partitioning code in this
            # file stores the raw dataset samples, which are not ndarrays, so convert
            # first instead of raising TypeError.  ndarrays are passed through as-is.
            if not isinstance(img, np.ndarray):
                img = np.array(img)
            img = torch.from_numpy(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label

    def __len__(self):
        """Number of samples held by this client."""
        return len(self.data)

def rearrange_data_by_class(data, targets, n_class):
    """Split ``data`` into per-class groups.

    Entry ``c`` of the returned list (length ``n_class``) is ``data`` indexed by the
    boolean mask ``targets == c``; array-style inputs are expected.
    """
    per_class = []
    for cls in range(n_class):
        per_class.append(data[targets == cls])
    return per_class

def cifar100_hetero_dir_part(_print = print,seed=2023,dataset_path="",save_path = "",num_clients=10,balance=None,partition="dirichlet",dir_alpha=0.3):
    """Partition CIFAR-100 across clients and save every client's train/test split.

    The training set is split with ``CIFAR100Partitioner``.  For each client, and for
    each class present in that client's training data, the next
    ``int(len(class_test_samples) / num_clients)`` not-yet-assigned test samples of
    that class are added to the client's test split.

    Args:
        _print: callable used to report each client's classes.
        seed, balance, partition, dir_alpha: forwarded to ``CIFAR100Partitioner``.
        dataset_path: root directory handed to ``torchvision.datasets.CIFAR100``.
        save_path: output directory; files are written to
            ``<save_path>/train/data_train_<cid>.pt`` and
            ``<save_path>/test/data_test_<cid>.pt``.
        num_clients: number of clients to partition the data over.

    Returns:
        The ``CIFAR100Partitioner`` instance that produced the split.
    """
    # Train/test datasets (no transform is passed, so samples are stored as loaded).
    trainset = torchvision.datasets.CIFAR100(root=dataset_path,
                                             train=True, download=True)
    testset = torchvision.datasets.CIFAR100(root=dataset_path,
                                            train=False, download=False)

    partitioner = CIFAR100Partitioner(trainset.targets,
                                      num_clients,
                                      balance=balance,
                                      partition=partition,
                                      dir_alpha=dir_alpha,
                                      seed=seed)

    data_indices = partitioner.client_dict
    class_nm = 100
    client_traindata_path = os.path.join(save_path, "train")
    client_testdata_path = os.path.join(save_path, "test")
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(client_traindata_path, exist_ok=True)
    os.makedirs(client_testdata_path, exist_ok=True)

    trainsamples, trainlabels = [], []
    for x, y in trainset:
        trainsamples.append(x)
        trainlabels.append(y)

    # Object arrays so the test samples can be selected with boolean masks.
    testsamples = np.empty((len(testset),), dtype=object)
    testlabels = np.empty((len(testset),), dtype=object)
    for i, z in enumerate(testset):
        testsamples[i] = z[0]
        testlabels[i] = z[1]
    rearrange_testsamples = rearrange_data_by_class(testsamples, testlabels, class_nm)
    # Per-class cursor: how many test samples of each class are already handed out.
    testdata_nmidx = {l: 0 for l in range(class_nm)}

    # `cid` instead of `id` so the builtin is not shadowed.
    for cid, indices in data_indices.items():
        traindata = [trainsamples[idx] for idx in indices]
        trainlabel = [trainlabels[idx] for idx in indices]

        user_sampled_labels = list(set(trainlabel))
        _print("client {}'s classes:{}".format(cid, user_sampled_labels))
        testdata, testlabel = [], []
        for l in user_sampled_labels:
            class_samples = rearrange_testsamples[l]
            start = testdata_nmidx[l]
            num_samples = int(len(class_samples) / num_clients)
            assert num_samples + start <= len(class_samples)
            testdata += class_samples[start:start + num_samples].tolist()
            testlabel += (l * np.ones(num_samples, dtype=int)).tolist()
            assert len(testdata) == len(testlabel), f"{len(testdata)} == {len(testlabel)}"
            testdata_nmidx[l] = start + num_samples

        train_dataset = list(zip(traindata, trainlabel))
        test_dataset = list(zip(testdata, testlabel))
        torch.save(
            train_dataset,
            os.path.join(client_traindata_path, "data_train_{}.pt".format(cid)))
        torch.save(
            test_dataset,
            os.path.join(client_testdata_path, "data_test_{}.pt".format(cid)))

    return partitioner


Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .partition import DataPartitioner, BasicPartitioner, VisionPartitioner
from .partition import CIFAR10Partitioner, CIFAR100Partitioner, FMNISTPartitioner, MNISTPartitioner, \
SVHNPartitioner
from .partition import FCUBEPartitioner
from .partition import AdultPartitioner, RCV1Partitioner, CovtypePartitioner
Loading

0 comments on commit bbefd23

Please sign in to comment.