-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
john
committed
Dec 9, 2024
1 parent
a70e1ab
commit bbefd23
Showing
23 changed files
with
3,157 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#coding:utf8 | ||
import warnings | ||
import argparse | ||
import sys | ||
|
||
def parse_arguments(argv): | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--function', type=str, help='Name of the function you called.', default="") | ||
parser.add_argument('--lr_decay',help='learning rate decay rate',type=float,default=0.998) | ||
parser.add_argument("--dataset", type=str, default="cifar100") | ||
parser.add_argument("--dataset_mean", type=tuple, default= (0.4914, 0.4822, 0.4465)) #cifar10 | ||
parser.add_argument("--dataset_std", type=tuple, default= (0.2023, 0.1994, 0.2010)) #cifar10 | ||
parser.add_argument("--model_name", type=str, default="ResNet18") | ||
parser.add_argument("--model_outputdim", type=int, default=10) | ||
parser.add_argument("--train", type=int, default=1, choices=[0,1]) | ||
parser.add_argument("--algorithm", type=str, default="FedSlaug") | ||
parser.add_argument("--batch_size", type=int, default=64) | ||
parser.add_argument("--learning_rate", type=float, default=0.01, help="Local learning rate") | ||
parser.add_argument("--dataset_path", type=str, default="", help="the dataset path") | ||
parser.add_argument("--seed", type=int, default=2023, help="") | ||
parser.add_argument("--partition_method", type=str, default="dirichlet", help="") | ||
parser.add_argument("--dirichlet_alpha",type=int, default=0.1, help="") | ||
parser.add_argument("--total_clients", type=int, default=100, help="") | ||
parser.add_argument('--gpu_num', type=str, default='0', help='choose which gpu to use') | ||
parser.add_argument("--num_glob_rounds", type=int, default=300) | ||
parser.add_argument("--local_epochs", type=int, default=10) | ||
parser.add_argument("--num_clients_per_round", type=int, default=10, help="Number of Users per round") | ||
parser.add_argument("--repeat_times", type=int, default=3, help="total repeat times") | ||
parser.add_argument("--save_path", type=str, default="./results/FedSlaug", help="directory path to save results") | ||
parser.add_argument("--eval_every", type=int, default=1, help="the number of rounds to evaluate the model performance. 1 is recommend here.") | ||
parser.add_argument("--save_every", type=int, default=1, help="the number of rounds to save the model.") | ||
parser.add_argument('--weight_decay', type=float, help='weight decay',default=5e-04) | ||
parser.add_argument('--genmodel_weight_decay', type=float, help='weight decay',default=1e-04) | ||
parser.add_argument('--GAN_type', type=int , help='0:case 2 weak gan, 1:case 1 strong gan', default=1) | ||
parser.add_argument('--GAN_dir', type=str , help='model dir of trained gan',default='') | ||
parser.add_argument('--GAN_name', type=str , help='model name of trained gan',default='') | ||
args = parser.parse_args() | ||
|
||
return parser.parse_args(argv) | ||
|
||
|
||
flargs = parse_arguments(sys.argv[1:]) | ||
|
||
def print_parameters(_print,args): | ||
|
||
for k, v in args.__dict__.items(): | ||
_print(k + " " + str(v)) | ||
|
||
if __name__=='__main__': | ||
args = parse_arguments(sys.argv[1:]) | ||
print(type(args)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import torch | ||
import torchvision | ||
import os | ||
import sys | ||
sys.path.append("..") | ||
from tqdm import trange | ||
import torchvision.transforms as T | ||
import numpy as np | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
import seaborn as sns | ||
import torch.utils.data as data | ||
from data.data_division_utils.partition import CIFAR10Partitioner | ||
from torchvision import datasets | ||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as transforms | ||
|
||
num_clients = 10 | ||
hist_color = '#4169E1' | ||
|
||
class client_cifar10(data.Dataset): | ||
def __init__(self,indices, mode='train', data_root_dir=None, transform=None, target_transform=None): | ||
self.cid = indices | ||
self.client_dataset = torch.load(os.path.join(data_root_dir, mode, "data_{}_{}.pt".format(mode,self.cid))) | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
self.src_data, self.src_label = zip(*self.client_dataset) | ||
self.data = list(self.src_data) | ||
self.label = list(self.src_label) | ||
self.src_datalen = len(self.src_label) | ||
self.src_class_list = list(set(self.src_label)) | ||
|
||
def get_each_class_nm(self,label): | ||
y_np = np.array(label) | ||
datanm_byclass= [] | ||
class_list = list(set(label)) | ||
for i in class_list: | ||
datanm_byclass.append(y_np[y_np==i].size) | ||
return class_list,datanm_byclass | ||
|
||
def __getitem__(self, index): | ||
img, label = self.data[index], self.label[index] | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
else: | ||
img = torch.from_numpy(img) | ||
|
||
if self.target_transform is not None: | ||
label = self.target_transform(label) | ||
|
||
return img, label | ||
|
||
def __len__(self): | ||
# return len(self.client_dataset.x) | ||
return len(self.data) | ||
|
||
def rearrange_data_by_class(data, targets, n_class): | ||
new_data = [] | ||
for i in range(n_class): | ||
idx = targets == i | ||
new_data.append(data[idx]) | ||
return new_data | ||
|
||
def cifar10_hetero_dir_part(_print = print,seed=2023,dataset_path = "",save_path = "",num_clients=10,balance=None,partition="dirichlet",dir_alpha=0.3): | ||
#Train test dataset partion | ||
trainset = torchvision.datasets.CIFAR10(root=dataset_path, | ||
train=True, download=True) | ||
testset = torchvision.datasets.CIFAR10(root=dataset_path, | ||
train=False, download=False) | ||
|
||
partitioner = CIFAR10Partitioner(trainset.targets, | ||
num_clients, | ||
balance=balance, | ||
partition=partition, | ||
dir_alpha=dir_alpha, | ||
seed=seed) | ||
|
||
data_indices = partitioner.client_dict | ||
class_nm = 10 | ||
client_traindata_path = os.path.join(save_path, "train") | ||
client_testdata_path = os.path.join(save_path, "test") | ||
if not os.path.exists(client_traindata_path): | ||
os.makedirs(client_traindata_path) | ||
if not os.path.exists(client_testdata_path): | ||
os.makedirs(client_testdata_path) | ||
|
||
trainsamples, trainlabels = [], [] | ||
for x, y in trainset: | ||
trainsamples.append(x) | ||
trainlabels.append(y) | ||
testsamples = np.empty((len(testset),),dtype=object) | ||
testlabels = np.empty((len(testset),),dtype=object) | ||
for i, z in enumerate(testset): | ||
testsamples[i]=z[0] | ||
testlabels[i]=z[1] | ||
rearrange_testsamples = rearrange_data_by_class(testsamples,testlabels,class_nm) | ||
testdata_nmidx = {l:0 for l in [i for i in range(class_nm)]} | ||
# print(testdata_nmidx) | ||
|
||
for id, indices in data_indices.items(): | ||
traindata, trainlabel = [], [] | ||
for idx in indices: | ||
x, y = trainsamples[idx], trainlabels[idx] | ||
traindata.append(x) | ||
trainlabel.append(y) | ||
|
||
user_sampled_labels = list(set(trainlabel)) | ||
_print("client {}'s classes:{}".format(id,user_sampled_labels)) | ||
testdata, testlabel = [], [] | ||
for l in user_sampled_labels: | ||
num_samples = int(len(rearrange_testsamples[l]) / num_clients ) | ||
assert num_samples + testdata_nmidx[l] <= len(rearrange_testsamples[l]) | ||
testdata += rearrange_testsamples[l][testdata_nmidx[l]:testdata_nmidx[l] + num_samples].tolist() | ||
testlabel += (l * np.ones(num_samples,dtype = int)).tolist() | ||
assert len(testdata) == len(testlabel), f"{len(testdata)} == {len(testlabel)}" | ||
testdata_nmidx[l] += num_samples | ||
|
||
train_dataset = [(x, y) for x, y in zip(traindata, trainlabel)] | ||
test_dataset = [(x, y) for x, y in zip(testdata, testlabel)] | ||
torch.save( | ||
train_dataset, | ||
os.path.join(client_traindata_path, "data_train_{}.pt".format(id))) | ||
torch.save( | ||
test_dataset, | ||
os.path.join(client_testdata_path, "data_test_{}.pt".format(id))) | ||
|
||
return partitioner | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import torch | ||
import torchvision | ||
import os | ||
from tqdm import trange | ||
import torchvision.transforms as T | ||
import numpy as np | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
import seaborn as sns | ||
import sys | ||
sys.path.append("..") | ||
import torch.utils.data as data | ||
from data.data_division_utils.partition import CIFAR100Partitioner | ||
from torchvision import datasets | ||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as transforms | ||
|
||
num_clients = 10 | ||
hist_color = '#4169E1' | ||
|
||
class client_cifar100(data.Dataset): | ||
def __init__(self,indices, mode='train', data_root_dir=None, transform=None, target_transform=None): | ||
self.cid = indices | ||
self.client_dataset = torch.load(os.path.join(data_root_dir, mode, "data_{}_{}.pt".format(mode,self.cid))) | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
self.src_data, self.src_label = zip(*self.client_dataset) | ||
self.data = list(self.src_data) | ||
self.label = list(self.src_label) | ||
self.src_datalen = len(self.src_label) | ||
self.src_class_list = list(set(self.src_label)) | ||
|
||
def get_each_class_nm(self,label): | ||
y_np = np.array(label) | ||
datanm_byclass= [] | ||
class_list = list(set(label)) | ||
for i in class_list: | ||
datanm_byclass.append(y_np[y_np==i].size) | ||
return class_list,datanm_byclass | ||
|
||
def __getitem__(self, index): | ||
img, label = self.data[index], self.label[index] | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
else: | ||
img = torch.from_numpy(img) | ||
|
||
if self.target_transform is not None: | ||
label = self.target_transform(label) | ||
|
||
return img, label | ||
|
||
def __len__(self): | ||
# return len(self.client_dataset.x) | ||
return len(self.data) | ||
|
||
def rearrange_data_by_class(data, targets, n_class): | ||
new_data = [] | ||
for i in range(n_class): | ||
idx = targets == i | ||
new_data.append(data[idx]) | ||
return new_data | ||
|
||
def cifar100_hetero_dir_part(_print = print,seed=2023,dataset_path="",save_path = "",num_clients=10,balance=None,partition="dirichlet",dir_alpha=0.3): | ||
#Train test dataset partion | ||
trainset = torchvision.datasets.CIFAR100(root=dataset_path, | ||
train=True, download=True) | ||
testset = torchvision.datasets.CIFAR100(root=dataset_path, | ||
train=False, download=False) | ||
|
||
partitioner = CIFAR100Partitioner(trainset.targets, | ||
num_clients, | ||
balance=balance, | ||
partition=partition, | ||
dir_alpha=dir_alpha, | ||
seed=seed) | ||
|
||
data_indices = partitioner.client_dict | ||
class_nm = 100 | ||
client_traindata_path = os.path.join(save_path, "train") | ||
client_testdata_path = os.path.join(save_path, "test") | ||
if not os.path.exists(client_traindata_path): | ||
os.makedirs(client_traindata_path) | ||
if not os.path.exists(client_testdata_path): | ||
os.makedirs(client_testdata_path) | ||
|
||
trainsamples, trainlabels = [], [] | ||
for x, y in trainset: | ||
trainsamples.append(x) | ||
trainlabels.append(y) | ||
testsamples = np.empty((len(testset),),dtype=object) | ||
testlabels = np.empty((len(testset),),dtype=object) | ||
for i, z in enumerate(testset): | ||
testsamples[i]=z[0] | ||
testlabels[i]=z[1] | ||
rearrange_testsamples = rearrange_data_by_class(testsamples,testlabels,class_nm) | ||
testdata_nmidx = {l:0 for l in [i for i in range(class_nm)]} | ||
# print(testdata_nmidx) | ||
|
||
for id, indices in data_indices.items(): | ||
traindata, trainlabel = [], [] | ||
for idx in indices: | ||
x, y = trainsamples[idx], trainlabels[idx] | ||
traindata.append(x) | ||
trainlabel.append(y) | ||
|
||
user_sampled_labels = list(set(trainlabel)) | ||
_print("client {}'s classes:{}".format(id,user_sampled_labels)) | ||
testdata, testlabel = [], [] | ||
for l in user_sampled_labels: | ||
num_samples = int(len(rearrange_testsamples[l]) / num_clients ) | ||
assert num_samples + testdata_nmidx[l] <= len(rearrange_testsamples[l]) | ||
testdata += rearrange_testsamples[l][testdata_nmidx[l]:testdata_nmidx[l] + num_samples].tolist() | ||
testlabel += (l * np.ones(num_samples,dtype = int)).tolist() | ||
assert len(testdata) == len(testlabel), f"{len(testdata)} == {len(testlabel)}" | ||
testdata_nmidx[l] += num_samples | ||
|
||
train_dataset = [(x, y) for x, y in zip(traindata, trainlabel)] | ||
test_dataset = [(x, y) for x, y in zip(testdata, testlabel)] | ||
torch.save( | ||
train_dataset, | ||
os.path.join(client_traindata_path, "data_train_{}.pt".format(id))) | ||
torch.save( | ||
test_dataset, | ||
os.path.join(client_testdata_path, "data_test_{}.pt".format(id))) | ||
|
||
return partitioner | ||
|
||
|
19 changes: 19 additions & 0 deletions
19
The_Collaborative_Training_Phase/data/data_division_utils/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .partition import DataPartitioner, BasicPartitioner, VisionPartitioner | ||
from .partition import CIFAR10Partitioner, CIFAR100Partitioner, FMNISTPartitioner, MNISTPartitioner, \ | ||
SVHNPartitioner | ||
from .partition import FCUBEPartitioner | ||
from .partition import AdultPartitioner, RCV1Partitioner, CovtypePartitioner |
Oops, something went wrong.