diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 28c39a6..426ec00 --- a/README.md +++ b/README.md @@ -9,11 +9,19 @@ The following figure shows an overview of our SlaugFL. The SlaugFL consists of t We divide the implementation code into two parts: The Preparation Phase and The Collaborative Training Phase. +## Setup +Install packages in the requirements. + ## The Preparation Phase ## The Collaborative Training Phase +1. Change the arguments in config.py +2. Run the following script: +``` +python main.py --function=run_job +``` ## Citation @@ -27,4 +35,8 @@ If you find this code is useful to your research, please consider to cite our pa year={2024}, publisher={IEEE} } -``` \ No newline at end of file +``` + +## Reference code +1. FedLab: https://github.com/SMILELab-FL/FedLab +2. FedGen: https://github.com/zhuangdizhu/FedGen \ No newline at end of file diff --git a/The_Collaborative_Training_Phase/config.py b/The_Collaborative_Training_Phase/config.py new file mode 100644 index 0000000..f52f383 --- /dev/null +++ b/The_Collaborative_Training_Phase/config.py @@ -0,0 +1,53 @@ +#coding:utf8 +import warnings +import argparse +import sys + +def parse_arguments(argv): + + + parser = argparse.ArgumentParser() + parser.add_argument('--function', type=str, help='Name of the function you called.', default="") + parser.add_argument('--lr_decay',help='learning rate decay rate',type=float,default=0.998) + parser.add_argument("--dataset", type=str, default="cifar100") + parser.add_argument("--dataset_mean", type=tuple, default= (0.4914, 0.4822, 0.4465)) #cifar10 + parser.add_argument("--dataset_std", type=tuple, default= (0.2023, 0.1994, 0.2010)) #cifar10 + parser.add_argument("--model_name", type=str, default="ResNet18") + parser.add_argument("--model_outputdim", type=int, default=10) + parser.add_argument("--train", type=int, default=1, choices=[0,1]) + parser.add_argument("--algorithm", type=str, default="FedSlaug") + parser.add_argument("--batch_size", 
type=int, default=64) + parser.add_argument("--learning_rate", type=float, default=0.01, help="Local learning rate") + parser.add_argument("--dataset_path", type=str, default="", help="the dataset path") + parser.add_argument("--seed", type=int, default=2023, help="") + parser.add_argument("--partition_method", type=str, default="dirichlet", help="") + parser.add_argument("--dirichlet_alpha",type=int, default=0.1, help="") + parser.add_argument("--total_clients", type=int, default=100, help="") + parser.add_argument('--gpu_num', type=str, default='0', help='choose which gpu to use') + parser.add_argument("--num_glob_rounds", type=int, default=300) + parser.add_argument("--local_epochs", type=int, default=10) + parser.add_argument("--num_clients_per_round", type=int, default=10, help="Number of Users per round") + parser.add_argument("--repeat_times", type=int, default=3, help="total repeat times") + parser.add_argument("--save_path", type=str, default="./results/FedSlaug", help="directory path to save results") + parser.add_argument("--eval_every", type=int, default=1, help="the number of rounds to evaluate the model performance. 
1 is recommend here.") + parser.add_argument("--save_every", type=int, default=1, help="the number of rounds to save the model.") + parser.add_argument('--weight_decay', type=float, help='weight decay',default=5e-04) + parser.add_argument('--genmodel_weight_decay', type=float, help='weight decay',default=1e-04) + parser.add_argument('--GAN_type', type=int , help='0:case 2 weak gan, 1:case 1 strong gan', default=1) + parser.add_argument('--GAN_dir', type=str , help='model dir of trained gan',default='') + parser.add_argument('--GAN_name', type=str , help='model name of trained gan',default='') + args = parser.parse_args() + + return parser.parse_args(argv) + + +flargs = parse_arguments(sys.argv[1:]) + +def print_parameters(_print,args): + + for k, v in args.__dict__.items(): + _print(k + " " + str(v)) + +if __name__=='__main__': + args = parse_arguments(sys.argv[1:]) + print(type(args)) diff --git a/The_Collaborative_Training_Phase/data/__init__.py b/The_Collaborative_Training_Phase/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/The_Collaborative_Training_Phase/data/cifar10.py b/The_Collaborative_Training_Phase/data/cifar10.py new file mode 100644 index 0000000..9881216 --- /dev/null +++ b/The_Collaborative_Training_Phase/data/cifar10.py @@ -0,0 +1,133 @@ +import torch +import torchvision +import os +import sys +sys.path.append("..") +from tqdm import trange +import torchvision.transforms as T +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from PIL import Image +import seaborn as sns +import torch.utils.data as data +from data.data_division_utils.partition import CIFAR10Partitioner +from torchvision import datasets +from torch.utils.data import DataLoader +import torchvision.transforms as transforms + +num_clients = 10 +hist_color = '#4169E1' + +class client_cifar10(data.Dataset): + def __init__(self,indices, mode='train', data_root_dir=None, transform=None, target_transform=None): + self.cid = indices + 
self.client_dataset = torch.load(os.path.join(data_root_dir, mode, "data_{}_{}.pt".format(mode,self.cid))) + self.transform = transform + self.target_transform = target_transform + self.src_data, self.src_label = zip(*self.client_dataset) + self.data = list(self.src_data) + self.label = list(self.src_label) + self.src_datalen = len(self.src_label) + self.src_class_list = list(set(self.src_label)) + + def get_each_class_nm(self,label): + y_np = np.array(label) + datanm_byclass= [] + class_list = list(set(label)) + for i in class_list: + datanm_byclass.append(y_np[y_np==i].size) + return class_list,datanm_byclass + + def __getitem__(self, index): + img, label = self.data[index], self.label[index] + + if self.transform is not None: + img = self.transform(img) + else: + img = torch.from_numpy(img) + + if self.target_transform is not None: + label = self.target_transform(label) + + return img, label + + def __len__(self): + # return len(self.client_dataset.x) + return len(self.data) + +def rearrange_data_by_class(data, targets, n_class): + new_data = [] + for i in range(n_class): + idx = targets == i + new_data.append(data[idx]) + return new_data + +def cifar10_hetero_dir_part(_print = print,seed=2023,dataset_path = "",save_path = "",num_clients=10,balance=None,partition="dirichlet",dir_alpha=0.3): + #Train test dataset partion + trainset = torchvision.datasets.CIFAR10(root=dataset_path, + train=True, download=True) + testset = torchvision.datasets.CIFAR10(root=dataset_path, + train=False, download=False) + + partitioner = CIFAR10Partitioner(trainset.targets, + num_clients, + balance=balance, + partition=partition, + dir_alpha=dir_alpha, + seed=seed) + + data_indices = partitioner.client_dict + class_nm = 10 + client_traindata_path = os.path.join(save_path, "train") + client_testdata_path = os.path.join(save_path, "test") + if not os.path.exists(client_traindata_path): + os.makedirs(client_traindata_path) + if not os.path.exists(client_testdata_path): + 
os.makedirs(client_testdata_path) + + trainsamples, trainlabels = [], [] + for x, y in trainset: + trainsamples.append(x) + trainlabels.append(y) + testsamples = np.empty((len(testset),),dtype=object) + testlabels = np.empty((len(testset),),dtype=object) + for i, z in enumerate(testset): + testsamples[i]=z[0] + testlabels[i]=z[1] + rearrange_testsamples = rearrange_data_by_class(testsamples,testlabels,class_nm) + testdata_nmidx = {l:0 for l in [i for i in range(class_nm)]} + # print(testdata_nmidx) + + for id, indices in data_indices.items(): + traindata, trainlabel = [], [] + for idx in indices: + x, y = trainsamples[idx], trainlabels[idx] + traindata.append(x) + trainlabel.append(y) + + user_sampled_labels = list(set(trainlabel)) + _print("client {}'s classes:{}".format(id,user_sampled_labels)) + testdata, testlabel = [], [] + for l in user_sampled_labels: + num_samples = int(len(rearrange_testsamples[l]) / num_clients ) + assert num_samples + testdata_nmidx[l] <= len(rearrange_testsamples[l]) + testdata += rearrange_testsamples[l][testdata_nmidx[l]:testdata_nmidx[l] + num_samples].tolist() + testlabel += (l * np.ones(num_samples,dtype = int)).tolist() + assert len(testdata) == len(testlabel), f"{len(testdata)} == {len(testlabel)}" + testdata_nmidx[l] += num_samples + + train_dataset = [(x, y) for x, y in zip(traindata, trainlabel)] + test_dataset = [(x, y) for x, y in zip(testdata, testlabel)] + torch.save( + train_dataset, + os.path.join(client_traindata_path, "data_train_{}.pt".format(id))) + torch.save( + test_dataset, + os.path.join(client_testdata_path, "data_test_{}.pt".format(id))) + + return partitioner + + + + diff --git a/The_Collaborative_Training_Phase/data/cifar100.py b/The_Collaborative_Training_Phase/data/cifar100.py new file mode 100644 index 0000000..26ac431 --- /dev/null +++ b/The_Collaborative_Training_Phase/data/cifar100.py @@ -0,0 +1,131 @@ +import torch +import torchvision +import os +from tqdm import trange +import torchvision.transforms 
as T +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from PIL import Image +import seaborn as sns +import sys +sys.path.append("..") +import torch.utils.data as data +from data.data_division_utils.partition import CIFAR100Partitioner +from torchvision import datasets +from torch.utils.data import DataLoader +import torchvision.transforms as transforms + +num_clients = 10 +hist_color = '#4169E1' + +class client_cifar100(data.Dataset): + def __init__(self,indices, mode='train', data_root_dir=None, transform=None, target_transform=None): + self.cid = indices + self.client_dataset = torch.load(os.path.join(data_root_dir, mode, "data_{}_{}.pt".format(mode,self.cid))) + self.transform = transform + self.target_transform = target_transform + self.src_data, self.src_label = zip(*self.client_dataset) + self.data = list(self.src_data) + self.label = list(self.src_label) + self.src_datalen = len(self.src_label) + self.src_class_list = list(set(self.src_label)) + + def get_each_class_nm(self,label): + y_np = np.array(label) + datanm_byclass= [] + class_list = list(set(label)) + for i in class_list: + datanm_byclass.append(y_np[y_np==i].size) + return class_list,datanm_byclass + + def __getitem__(self, index): + img, label = self.data[index], self.label[index] + + if self.transform is not None: + img = self.transform(img) + else: + img = torch.from_numpy(img) + + if self.target_transform is not None: + label = self.target_transform(label) + + return img, label + + def __len__(self): + # return len(self.client_dataset.x) + return len(self.data) + +def rearrange_data_by_class(data, targets, n_class): + new_data = [] + for i in range(n_class): + idx = targets == i + new_data.append(data[idx]) + return new_data + +def cifar100_hetero_dir_part(_print = print,seed=2023,dataset_path="",save_path = "",num_clients=10,balance=None,partition="dirichlet",dir_alpha=0.3): + #Train test dataset partion + trainset = torchvision.datasets.CIFAR100(root=dataset_path, + 
train=True, download=True) + testset = torchvision.datasets.CIFAR100(root=dataset_path, + train=False, download=False) + + partitioner = CIFAR100Partitioner(trainset.targets, + num_clients, + balance=balance, + partition=partition, + dir_alpha=dir_alpha, + seed=seed) + + data_indices = partitioner.client_dict + class_nm = 100 + client_traindata_path = os.path.join(save_path, "train") + client_testdata_path = os.path.join(save_path, "test") + if not os.path.exists(client_traindata_path): + os.makedirs(client_traindata_path) + if not os.path.exists(client_testdata_path): + os.makedirs(client_testdata_path) + + trainsamples, trainlabels = [], [] + for x, y in trainset: + trainsamples.append(x) + trainlabels.append(y) + testsamples = np.empty((len(testset),),dtype=object) + testlabels = np.empty((len(testset),),dtype=object) + for i, z in enumerate(testset): + testsamples[i]=z[0] + testlabels[i]=z[1] + rearrange_testsamples = rearrange_data_by_class(testsamples,testlabels,class_nm) + testdata_nmidx = {l:0 for l in [i for i in range(class_nm)]} + # print(testdata_nmidx) + + for id, indices in data_indices.items(): + traindata, trainlabel = [], [] + for idx in indices: + x, y = trainsamples[idx], trainlabels[idx] + traindata.append(x) + trainlabel.append(y) + + user_sampled_labels = list(set(trainlabel)) + _print("client {}'s classes:{}".format(id,user_sampled_labels)) + testdata, testlabel = [], [] + for l in user_sampled_labels: + num_samples = int(len(rearrange_testsamples[l]) / num_clients ) + assert num_samples + testdata_nmidx[l] <= len(rearrange_testsamples[l]) + testdata += rearrange_testsamples[l][testdata_nmidx[l]:testdata_nmidx[l] + num_samples].tolist() + testlabel += (l * np.ones(num_samples,dtype = int)).tolist() + assert len(testdata) == len(testlabel), f"{len(testdata)} == {len(testlabel)}" + testdata_nmidx[l] += num_samples + + train_dataset = [(x, y) for x, y in zip(traindata, trainlabel)] + test_dataset = [(x, y) for x, y in zip(testdata, testlabel)] + 
torch.save( + train_dataset, + os.path.join(client_traindata_path, "data_train_{}.pt".format(id))) + torch.save( + test_dataset, + os.path.join(client_testdata_path, "data_test_{}.pt".format(id))) + + return partitioner + + diff --git a/The_Collaborative_Training_Phase/data/data_division_utils/__init__.py b/The_Collaborative_Training_Phase/data/data_division_utils/__init__.py new file mode 100644 index 0000000..a5bbeb1 --- /dev/null +++ b/The_Collaborative_Training_Phase/data/data_division_utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .partition import DataPartitioner, BasicPartitioner, VisionPartitioner +from .partition import CIFAR10Partitioner, CIFAR100Partitioner, FMNISTPartitioner, MNISTPartitioner, \ + SVHNPartitioner +from .partition import FCUBEPartitioner +from .partition import AdultPartitioner, RCV1Partitioner, CovtypePartitioner diff --git a/The_Collaborative_Training_Phase/data/data_division_utils/functional.py b/The_Collaborative_Training_Phase/data/data_division_utils/functional.py new file mode 100644 index 0000000..8a7cbef --- /dev/null +++ b/The_Collaborative_Training_Phase/data/data_division_utils/functional.py @@ -0,0 +1,475 @@ +# Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +import warnings + + +def split_indices(num_cumsum, rand_perm): + """Splice the sample index list given number of each client. + + Args: + num_cumsum (np.ndarray): Cumulative sum of sample number for each client. + rand_perm (list): List of random sample index. + + Returns: + dict: ``{ client_id: indices}``. + + """ + client_indices_pairs = [(cid, idxs) for cid, idxs in + enumerate(np.split(rand_perm, num_cumsum)[:-1])] + client_dict = dict(client_indices_pairs) + return client_dict + + +def balance_split(num_clients, num_samples): + """Assign same sample sample for each client. + + Args: + num_clients (int): Number of clients for partition. 
+ num_samples (int): Total number of samples. + + Returns: + numpy.ndarray: A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. + + """ + num_samples_per_client = int(num_samples / num_clients) + client_sample_nums = (np.ones(num_clients) * num_samples_per_client).astype( + int) + return client_sample_nums + + +def lognormal_unbalance_split(num_clients, num_samples, unbalance_sgm): + """Assign different sample number for each client using Log-Normal distribution. + + Sample numbers for clients are drawn from Log-Normal distribution. + + Args: + num_clients (int): Number of clients for partition. + num_samples (int): Total number of samples. + unbalance_sgm (float): Log-normal variance. When equals to ``0``, the partition is equal to :func:`balance_partition`. + + Returns: + numpy.ndarray: A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. + + """ + num_samples_per_client = int(num_samples / num_clients) + if unbalance_sgm != 0: + client_sample_nums = np.random.lognormal(mean=np.log(num_samples_per_client), + sigma=unbalance_sgm, + size=num_clients) + client_sample_nums = ( + client_sample_nums / np.sum(client_sample_nums) * num_samples).astype(int) + diff = np.sum(client_sample_nums) - num_samples # diff <= 0 + + # Add/Subtract the excess number starting from first client + if diff != 0: + for cid in range(num_clients): + if client_sample_nums[cid] > diff: + client_sample_nums[cid] -= diff + break + else: + client_sample_nums = (np.ones(num_clients) * num_samples_per_client).astype(int) + + return client_sample_nums + + +def dirichlet_unbalance_split(num_clients, num_samples, alpha): + """Assign different sample number for each client using Dirichlet distribution. + + Sample numbers for clients are drawn from Dirichlet distribution. + + Args: + num_clients (int): Number of clients for partition. + num_samples (int): Total number of samples. 
+ alpha (float): Dirichlet concentration parameter + + Returns: + numpy.ndarray: A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. + + """ + min_size = 0 + while min_size < 10: + proportions = np.random.dirichlet(np.repeat(alpha, num_clients)) + proportions = proportions / proportions.sum() + min_size = np.min(proportions * num_samples) + + client_sample_nums = (proportions * num_samples).astype(int) + return client_sample_nums + + +def homo_partition(client_sample_nums, num_samples): + """Partition data indices in IID way given sample numbers for each clients. + + Args: + client_sample_nums (numpy.ndarray): Sample numbers for each clients. + num_samples (int): Number of samples. + + Returns: + dict: ``{ client_id: indices}``. + + """ + rand_perm = np.random.permutation(num_samples) + num_cumsum = np.cumsum(client_sample_nums).astype(int) + client_dict = split_indices(num_cumsum, rand_perm) + return client_dict + + +def hetero_dir_partition(targets, num_clients, num_classes, dir_alpha, min_require_size=None): + """ + + Non-iid partition based on Dirichlet distribution. The method is from "hetero-dir" partition of + `Bayesian Nonparametric Federated Learning of Neural Networks `_ + and `Federated Learning with Matched Averaging `_. + + This method simulates heterogeneous partition for which number of data points and class + proportions are unbalanced. Samples will be partitioned into :math:`J` clients by sampling + :math:`p_k \sim \\text{Dir}_{J}({\\alpha})` and allocating a :math:`p_{p,j}` proportion of the + samples of class :math:`k` to local client :math:`j`. + + Sample number for each client is decided in this function. + + Args: + targets (list or numpy.ndarray): Sample targets. Unshuffled preferred. + num_clients (int): Number of clients for partition. + num_classes (int): Number of classes in samples. + dir_alpha (float): Parameter alpha for Dirichlet distribution. 
+ min_require_size (int, optional): Minimum required sample number for each client. If set to ``None``, then equals to ``num_classes``. + + Returns: + dict: ``{ client_id: indices}``. + """ + if min_require_size is None: + min_require_size = num_classes + + if not isinstance(targets, np.ndarray): + targets = np.array(targets) + num_samples = targets.shape[0] + + min_size = 0 + while min_size < min_require_size: + idx_batch = [[] for _ in range(num_clients)] + # for each class in the dataset + for k in range(num_classes): + idx_k = np.where(targets == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet( + np.repeat(dir_alpha, num_clients)) + # Balance + proportions = np.array( + [p * (len(idx_j) < num_samples / num_clients) for p, idx_j in + zip(proportions, idx_batch)]) + proportions = proportions / proportions.sum() + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in + zip(idx_batch, np.split(idx_k, proportions))] + min_size = min([len(idx_j) for idx_j in idx_batch]) + + client_dict = dict() + for cid in range(num_clients): + np.random.shuffle(idx_batch[cid]) + client_dict[cid] = np.array(idx_batch[cid]) + + return client_dict + + +def shards_partition(targets, num_clients, num_shards): + """Non-iid partition used in FedAvg `paper `_. + + Args: + targets (list or numpy.ndarray): Sample targets. Unshuffled preferred. + num_clients (int): Number of clients for partition. + num_shards (int): Number of shards in partition. + + Returns: + dict: ``{ client_id: indices}``. + + """ + if not isinstance(targets, np.ndarray): + targets = np.array(targets) + num_samples = targets.shape[0] + + size_shard = int(num_samples / num_shards) + if num_samples % num_shards != 0: + warnings.warn("warning: length of dataset isn't divided exactly by num_shards. 
" + "Some samples will be dropped.") + + shards_per_client = int(num_shards / num_clients) + if num_shards % num_clients != 0: + warnings.warn("warning: num_shards isn't divided exactly by num_clients. " + "Some shards will be dropped.") + + indices = np.arange(num_samples) + # sort sample indices according to labels + indices_targets = np.vstack((indices, targets)) + indices_targets = indices_targets[:, indices_targets[1, :].argsort()] + # corresponding labels after sorting are [0, .., 0, 1, ..., 1, ...] + sorted_indices = indices_targets[0, :] + + # permute shards idx, and slice shards_per_client shards for each client + rand_perm = np.random.permutation(num_shards) + num_client_shards = np.ones(num_clients) * shards_per_client + # sample index must be int + num_cumsum = np.cumsum(num_client_shards).astype(int) + # shard indices for each client + client_shards_dict = split_indices(num_cumsum, rand_perm) + + # map shard idx to sample idx for each client + client_dict = dict() + for cid in range(num_clients): + shards_set = client_shards_dict[cid] + current_indices = [ + sorted_indices[shard_id * size_shard: (shard_id + 1) * size_shard] + for shard_id in shards_set] + client_dict[cid] = np.concatenate(current_indices, axis=0) + + return client_dict + + +def client_inner_dirichlet_partition(targets, num_clients, num_classes, dir_alpha, + client_sample_nums, verbose=True): + """Non-iid Dirichlet partition. + + The method is from The method is from paper `Federated Learning Based on Dynamic Regularization `_. + This function can be used by given specific sample number for all clients ``client_sample_nums``. + It's different from :func:`hetero_dir_partition`. + + Args: + targets (list or numpy.ndarray): Sample targets. + num_clients (int): Number of clients for partition. + num_classes (int): Number of classes in samples. + dir_alpha (float): Parameter alpha for Dirichlet distribution. 
+ client_sample_nums (numpy.ndarray): A numpy array consisting ``num_clients`` integer elements, each represents sample number of corresponding clients. + verbose (bool, optional): Whether to print partition process. Default as ``True``. + + Returns: + dict: ``{ client_id: indices}``. + + """ + if not isinstance(targets, np.ndarray): + targets = np.array(targets) + + rand_perm = np.random.permutation(targets.shape[0]) + targets = targets[rand_perm] + + class_priors = np.random.dirichlet(alpha=[dir_alpha] * num_classes, + size=num_clients) + prior_cumsum = np.cumsum(class_priors, axis=1) + idx_list = [np.where(targets == i)[0] for i in range(num_classes)] + class_amount = [len(idx_list[i]) for i in range(num_classes)] + + client_indices = [np.zeros(client_sample_nums[cid]).astype(np.int64) for cid in + range(num_clients)] + + while np.sum(client_sample_nums) != 0: + curr_cid = np.random.randint(num_clients) + # If current node is full resample a client + if verbose: + print('Remaining Data: %d' % np.sum(client_sample_nums)) + if client_sample_nums[curr_cid] <= 0: + continue + client_sample_nums[curr_cid] -= 1 + curr_prior = prior_cumsum[curr_cid] + while True: + curr_class = np.argmax(np.random.uniform() <= curr_prior) + # Redraw class label if no rest in current class samples + if class_amount[curr_class] <= 0: + continue + class_amount[curr_class] -= 1 + client_indices[curr_cid][client_sample_nums[curr_cid]] = \ + idx_list[curr_class][class_amount[curr_class]] + + break + + client_dict = {cid: client_indices[cid] for cid in range(num_clients)} + return client_dict + + +def label_skew_quantity_based_partition(targets, num_clients, num_classes, major_classes_num): + """Label-skew:quantity-based partition. + + For details, please check `Federated Learning on Non-IID Data Silos: An Experimental Study `_. + + Args: + targets (List or np.ndarray): Labels od dataset. + num_clients (int): Number of clients. + num_classes (int): Number of unique classes. 
+ major_classes_num (int): Number of classes for each client, should be less then ``num_classes``. + + Returns: + dict: ``{ client_id: indices}``. + + """ + if not isinstance(targets, np.ndarray): + targets = np.array(targets) + + idx_batch = [np.ndarray(0, dtype=np.int64) for _ in range(num_clients)] + # only for major_classes_num < num_classes. + # if major_classes_num = num_classes, it equals to IID partition + times = [0 for _ in range(num_classes)] + contain = [] + for cid in range(num_clients): + current = [cid % num_classes] + times[cid % num_classes] += 1 + j = 1 + while j < major_classes_num: + ind = np.random.randint(num_classes) + if ind not in current: + j += 1 + current.append(ind) + times[ind] += 1 + contain.append(current) + + for k in range(num_classes): + idx_k = np.where(targets == k)[0] + np.random.shuffle(idx_k) + split = np.array_split(idx_k, times[k]) + ids = 0 + for cid in range(num_clients): + if k in contain[cid]: + idx_batch[cid] = np.append(idx_batch[cid], split[ids]) + ids += 1 + + client_dict = {cid: idx_batch[cid] for cid in range(num_clients)} + return client_dict + + +def fcube_synthetic_partition(data): + """Feature-distribution-skew:synthetic partition. + + Synthetic partition for FCUBE dataset. This partition is from `Federated Learning on Non-IID Data Silos: An Experimental Study `_. + + Args: + data (np.ndarray): Data of dataset :class:`FCUBE`. + + Returns: + dict: ``{ client_id: indices}``. 
+ """ + num_clients = 4 + client_indices = [[] for _ in range(num_clients)] + for idx, sample in enumerate(data): + p1, p2, p3 = sample + if (p1 > 0 and p2 > 0 and p3 > 0) or (p1 < 0 and p2 < 0 and p3 < 0): + client_indices[0].append(idx) + elif (p1 > 0 and p2 > 0 and p3 < 0) or (p1 < 0 and p2 < 0 and p3 > 0): + client_indices[1].append(idx) + elif (p1 > 0 and p2 < 0 and p3 > 0) or (p1 < 0 and p2 > 0 and p3 < 0): + client_indices[2].append(idx) + else: + client_indices[3].append(idx) + client_dict = {cid: np.array(client_indices[cid]).astype(int) for cid in range(num_clients)} + return client_dict + + +def samples_num_count(client_dict, num_clients): + """Return sample count for all clients in ``client_dict``. + + Args: + client_dict (dict): Data partition result for different clients. + num_clients (int): Total number of clients. + + Returns: + pandas.DataFrame + + """ + client_samples_nums = [[cid, client_dict[cid].shape[0]] for cid in + range(num_clients)] + client_sample_count = pd.DataFrame(data=client_samples_nums, + columns=['client', 'num_samples']).set_index('client') + return client_sample_count + +def noniid_slicing(dataset, num_clients, num_shards): + """Slice a dataset for non-IID. + + Args: + dataset (torch.utils.data.Dataset): Dataset to slice. + num_clients (int): Number of client. + num_shards (int): Number of shards. + + Notes: + The size of a shard equals to ``int(len(dataset)/num_shards)``. + Each client will get ``int(num_shards/num_clients)`` shards. + + Returns: + dict: ``{ 0: indices of dataset, 1: indices of dataset, ..., k: indices of dataset }`` + """ + total_sample_nums = len(dataset) + size_of_shards = int(total_sample_nums / num_shards) + if total_sample_nums % num_shards != 0: + warnings.warn( + "warning: the length of dataset isn't divided exactly by num_shard.some samples will be dropped." 
+ ) + # the number of shards that each one of clients can get + shard_pc = int(num_shards / num_clients) + if num_shards % num_clients != 0: + warnings.warn( + "warning: num_shard isn't divided exactly by num_clients. some samples will be dropped." + ) + + dict_users = {i: np.array([], dtype='int64') for i in range(num_clients)} + + labels = np.array(dataset.targets) + idxs = np.arange(total_sample_nums) + + # sort sample indices according to labels + idxs_labels = np.vstack((idxs, labels)) + idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] + idxs = idxs_labels[0, :] # corresponding labels after sorting are [0, .., 0, 1, ..., 1, ...] + + # assign + idx_shard = [i for i in range(num_shards)] + for i in range(num_clients): + rand_set = set(np.random.choice(idx_shard, shard_pc, replace=False)) + idx_shard = list(set(idx_shard) - rand_set) + for rand in rand_set: + dict_users[i] = np.concatenate( + (dict_users[i], + idxs[rand * size_of_shards:(rand + 1) * size_of_shards]), + axis=0) + + return dict_users + + +def random_slicing(dataset, num_clients): + """Slice a dataset randomly and equally for IID. + + Args: + dataset (torch.utils.data.Dataset): a dataset for slicing. + num_clients (int): the number of client. 
+ + Returns: + dict: ``{ 0: indices of dataset, 1: indices of dataset, ..., k: indices of dataset }`` + """ + num_items = int(len(dataset) / num_clients) + dict_users, all_idxs = {}, [i for i in range(len(dataset))] + for i in range(num_clients): + dict_users[i] = list( + np.random.choice(all_idxs, num_items, replace=False)) + all_idxs = list(set(all_idxs) - set(dict_users[i])) + return dict_users + +if __name__ == '__main__': + from torchvision import datasets, transforms + import torch + import numpy as np + dataset_train = datasets.CIFAR10('/home/demo2/FL_data/cifar-10/cifar-10-batches-py', train=True, download=False, transform=None) + targets = dataset_train.targets + client_idx=hetero_dir_partition(targets,20,10,0.05) + tensor = np.zeros((20,10)) + np.set_printoptions(suppress=True) + for cli in range(len(client_idx)): + for index in client_idx[cli]: + tensor[cli][targets[index]] +=1 + print(tensor) diff --git a/The_Collaborative_Training_Phase/data/data_division_utils/partition.py b/The_Collaborative_Training_Phase/data/data_division_utils/partition.py new file mode 100644 index 0000000..b051d4e --- /dev/null +++ b/The_Collaborative_Training_Phase/data/data_division_utils/partition.py @@ -0,0 +1,449 @@ +# Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +import numpy as np + +from . 
class DataPartitioner(ABC):
    """Abstract base class for federated-learning data partitioners.

    Concrete subclasses (e.g. ``BasicPartitioner``, ``CIFAR10Partitioner``)
    implement the partition scheme itself plus dict-like access to each
    client's sample indices.
    """

    def __init__(self):
        pass

    @abstractmethod
    def _perform_partition(self):
        """Compute and return the client-id -> sample-indices mapping."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, index):
        """Return the sample indices assigned to client ``index``."""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self):
        """Return the number of partitions (clients)."""
        raise NotImplementedError()
For the unbalance method, the sample number for each client is drawn from a Log-Normal
+ """ + + num_classes = 10 + + def __init__(self, targets, num_clients, + balance=True, partition="iid", + unbalance_sgm=0, + num_shards=None, + dir_alpha=None, + verbose=True, + min_require_size=None, + seed=None): + + self.targets = np.array(targets) # with shape (num_samples,) + self.num_samples = self.targets.shape[0] + self.num_clients = num_clients + self.client_dict = dict() + self.partition = partition + self.balance = balance + self.dir_alpha = dir_alpha + self.num_shards = num_shards + self.unbalance_sgm = unbalance_sgm + self.verbose = verbose + self.min_require_size = min_require_size + # self.rng = np.random.default_rng(seed) # rng currently not supports randint + np.random.seed(seed) + + # partition scheme check + if balance is None: + assert partition in ["dirichlet", "shards"], f"When balance=None, 'partition' only " \ + f"accepts 'dirichlet' and 'shards'." + elif isinstance(balance, bool): + assert partition in ["iid", "dirichlet"], f"When balance is bool, 'partition' only " \ + f"accepts 'dirichlet' and 'iid'." 
+ else: + raise ValueError(f"'balance' can only be NoneType or bool, not {type(balance)}.") + + # perform partition according to setting + self.client_dict = self._perform_partition() + # get sample number count for each client + self.client_sample_count = F.samples_num_count(self.client_dict, self.num_clients) + + def _perform_partition(self): + if self.balance is None: + if self.partition == "dirichlet": + client_dict = F.hetero_dir_partition(self.targets, + self.num_clients, + self.num_classes, + self.dir_alpha, + min_require_size=self.min_require_size) + + else: # partition is 'shards' + client_dict = F.shards_partition(self.targets, self.num_clients, self.num_shards) + + else: # if balance is True or False + # perform sample number balance/unbalance partition over all clients + if self.balance is True: + client_sample_nums = F.balance_split(self.num_clients, self.num_samples) + else: + client_sample_nums = F.lognormal_unbalance_split(self.num_clients, + self.num_samples, + self.unbalance_sgm) + + # perform iid/dirichlet partition for each client + if self.partition == "iid": + client_dict = F.homo_partition(client_sample_nums, self.num_samples) + else: # for dirichlet + client_dict = F.client_inner_dirichlet_partition(self.targets, self.num_clients, + self.num_classes, self.dir_alpha, + client_sample_nums, self.verbose) + + return client_dict + + def __getitem__(self, index): + """Obtain sample indices for client ``index``. + + Args: + index (int): Client ID. + + Returns: + list: List of sample indices for client ID ``index``. + + """ + return self.client_dict[index] + + def __len__(self): + """Usually equals to number of clients.""" + return len(self.client_dict) + + + + +class CIFAR100Partitioner(CIFAR10Partitioner): + """CIFAR100 data partitioner. + + This is a subclass of the :class:`CIFAR10Partitioner`. For details, please check `Federated Dataset and DataPartitioner `_. 
+ """ + num_classes = 100 + + +class BasicPartitioner(DataPartitioner): + """Basic data partitioner. + + Basic data partitioner, supported partition: + + - label-distribution-skew:quantity-based + + - label-distribution-skew:distributed-based (Dirichlet) + + - quantity-skew (Dirichlet) + + - IID + + For more details, please check `Federated Learning on Non-IID Data Silos: An Experimental Study `_ and `Federated Dataset and DataPartitioner `_. + + Args: + targets (list or numpy.ndarray): Sample targets. Unshuffled preferred. + num_clients (int): Number of clients for partition. + partition (str): Partition name. Only supports ``"noniid-#label"``, ``"noniid-labeldir"``, ``"unbalance"`` and ``"iid"`` partition schemes. + dir_alpha (float): Parameter alpha for Dirichlet distribution. Only works if ``partition="noniid-labeldir"``. + major_classes_num (int): Number of major class for each clients. Only works if ``partition="noniid-#label"``. + verbose (bool): Whether output intermediate information. Default as ``True``. + min_require_size (int, optional): Minimum required sample number for each client. If set to ``None``, then equals to ``num_classes``. Only works if ``partition="noniid-labeldir"``. + seed (int): Random seed. Default as ``None``. + + Returns: + dict: ``{ client_id: indices}``. 
+ """ + num_classes = 2 + + def __init__(self, targets, num_clients, + partition='iid', + dir_alpha=None, + major_classes_num=1, + verbose=True, + min_require_size=None, + seed=None): + self.targets = np.array(targets) # with shape (num_samples,) + self.num_samples = self.targets.shape[0] + self.num_clients = num_clients + self.client_dict = dict() + self.partition = partition + self.dir_alpha = dir_alpha + self.verbose = verbose + self.min_require_size = min_require_size + + # self.rng = np.random.default_rng(seed) # rng currently not supports randint + np.random.seed(seed) + + if partition == "noniid-#label": + # label-distribution-skew:quantity-based + assert isinstance(major_classes_num, int), f"'major_classes_num' should be integer, " \ + f"not {type(major_classes_num)}." + assert major_classes_num > 0, f"'major_classes_num' should be positive." + assert major_classes_num < self.num_classes, f"'major_classes_num' for each client " \ + f"should be less than number of total " \ + f"classes {self.num_classes}." + self.major_classes_num = major_classes_num + elif partition in ["noniid-labeldir", "unbalance"]: + # label-distribution-skew:distributed-based (Dirichlet) and quantity-skew (Dirichlet) + assert dir_alpha > 0, f"Parameter 'dir_alpha' for Dirichlet distribution should be " \ + f"positive." + elif partition == "iid": + # IID + pass + else: + raise ValueError( + f"tabular data partition only supports 'noniid-#label', 'noniid-labeldir', " + f"'unbalance', 'iid'. 
{partition} is not supported.") + + self.client_dict = self._perform_partition() + # get sample number count for each client + self.client_sample_count = F.samples_num_count(self.client_dict, self.num_clients) + + def _perform_partition(self): + if self.partition == "noniid-#label": + # label-distribution-skew:quantity-based + client_dict = F.label_skew_quantity_based_partition(self.targets, self.num_clients, + self.num_classes, + self.major_classes_num) + + elif self.partition == "noniid-labeldir": + # label-distribution-skew:distributed-based (Dirichlet) + client_dict = F.hetero_dir_partition(self.targets, self.num_clients, self.num_classes, + self.dir_alpha, + min_require_size=self.min_require_size) + + elif self.partition == "unbalance": + # quantity-skew (Dirichlet) + client_sample_nums = F.dirichlet_unbalance_split(self.num_clients, self.num_samples, + self.dir_alpha) + client_dict = F.homo_partition(client_sample_nums, self.num_samples) + + else: + # IID + client_sample_nums = F.balance_split(self.num_clients, self.num_samples) + client_dict = F.homo_partition(client_sample_nums, self.num_samples) + + return client_dict + + def __getitem__(self, index): + return self.client_dict[index] + + def __len__(self): + return len(self.client_dict) + + +class VisionPartitioner(BasicPartitioner): + """Data partitioner for vision data. + + Supported partition for vision data: + + - label-distribution-skew:quantity-based + + - label-distribution-skew:distributed-based (Dirichlet) + + - quantity-skew (Dirichlet) + + - IID + + For more details, please check `Federated Learning on Non-IID Data Silos: An Experimental Study `_. + + Args: + targets (list or numpy.ndarray): Sample targets. Unshuffled preferred. + num_clients (int): Number of clients for partition. + partition (str): Partition name. Only supports ``"noniid-#label"``, ``"noniid-labeldir"``, ``"unbalance"`` and ``"iid"`` partition schemes. + dir_alpha (float): Parameter alpha for Dirichlet distribution. 
Only works if ``partition="noniid-labeldir"``. + major_classes_num (int): Number of major class for each clients. Only works if ``partition="noniid-#label"``. + verbose (bool): Whether output intermediate information. Default as ``True``. + seed (int): Random seed. Default as ``None``. + + Returns: + dict: ``{ client_id: indices}``. + + """ + num_classes = 10 + + def __init__(self, targets, num_clients, + partition='iid', + dir_alpha=None, + major_classes_num=None, + verbose=True, + seed=None): + super(VisionPartitioner, self).__init__(targets=targets, num_clients=num_clients, + partition=partition, + dir_alpha=dir_alpha, + major_classes_num=major_classes_num, + verbose=verbose, + seed=seed) + + +class MNISTPartitioner(VisionPartitioner): + """Data partitioner for MNIST. + + For details, please check :class:`VisionPartitioner` and `Federated Dataset and DataPartitioner `_. + """ + num_features = 784 + + +class FMNISTPartitioner(VisionPartitioner): + """Data partitioner for FashionMNIST. + + For details, please check :class:`VisionPartitioner` and `Federated Dataset and DataPartitioner `_ + """ + num_features = 784 + + +class SVHNPartitioner(VisionPartitioner): + """Data partitioner for SVHN. + + For details, please check :class:`VisionPartitioner` and `Federated Dataset and DataPartitioner `_ + """ + num_features = 1024 + + +class FCUBEPartitioner(DataPartitioner): + """FCUBE data partitioner. + + FCUBE is a synthetic dataset for research in non-IID scenario with feature imbalance. This + dataset and its partition methods are proposed in `Federated Learning on Non-IID Data Silos: An + Experimental Study `_. + + Supported partition methods for FCUBE: + + - feature-distribution-skew:synthetic + + - IID + + For more details, please refer to Section (IV-B-b) of original paper. For detailed usage, please check `Federated Dataset and DataPartitioner `_. + + Args: + data (numpy.ndarray): Data of dataset :class:`FCUBE`. + partition (str): Partition type. 
Only supports `'synthetic'` and `'iid'`. + """ + num_classes = 2 + num_clients = 4 # only accept partition for 4 clients + + def __init__(self, data, partition): + if partition not in ['synthetic', 'iid']: + raise ValueError( + f"FCUBE only supports 'synthetic' and 'iid' partition, not {partition}.") + self.partition = partition + self.data = data + if isinstance(data, np.ndarray): + self.num_samples = data.shape[0] + else: + self.num_samples = len(data) + + self.client_dict = self._perform_partition() + + def _perform_partition(self): + if self.partition == 'synthetic': + # feature-distribution-skew:synthetic + client_dict = F.fcube_synthetic_partition(self.data) + else: + # IID partition + client_sample_nums = F.balance_split(self.num_clients, self.num_samples) + client_dict = F.homo_partition(client_sample_nums, self.num_samples) + + return client_dict + + def __getitem__(self, index): + return self.client_dict[index] + + def __len__(self): + return self.num_clients + + +class AdultPartitioner(BasicPartitioner): + """Data partitioner for Adult. + + For details, please check :class:`BasicPartitioner` and `Federated Dataset and DataPartitioner `_ + """ + num_features = 123 + num_classes = 2 + + +class RCV1Partitioner(BasicPartitioner): + """Data partitioner for RCV1. + + For details, please check :class:`BasicPartitioner` and `Federated Dataset and DataPartitioner `_ + """ + num_features = 47236 + num_classes = 2 + + +class CovtypePartitioner(BasicPartitioner): + """Data partitioner for Covtype. 
+ + For details, please check :class:`BasicPartitioner` and `Federated Dataset and DataPartitioner `_ + """ + num_features = 54 + num_classes = 2 diff --git a/The_Collaborative_Training_Phase/data/data_utils.py b/The_Collaborative_Training_Phase/data/data_utils.py new file mode 100644 index 0000000..2562715 --- /dev/null +++ b/The_Collaborative_Training_Phase/data/data_utils.py @@ -0,0 +1,178 @@ +import torch +import torchvision +import os +import sys +import numpy as np +sys.path.append("..") +from torchvision.datasets import ImageFolder +import torchvision.datasets as datasets +from data.cifar10 import cifar10_hetero_dir_part,client_cifar10 +from torch.utils.data import Dataset +from data.cifar100 import cifar100_hetero_dir_part,client_cifar100 +import torchvision.transforms as T +METRICS = ['glob_round','glob_acc', 'per_acc', 'glob_loss', 'per_loss', 'client_train_time', 'server_agg_time'] + +def data_division(args,iid_flag=False,balance=None,partition="dirichlet",dir_alpha=0.3): + + data_save_path = "./data/data_" + args.dataset + if args.dataset == "cifar10": + if iid_flag: + pass + else: + cifar10_hetero_dir_part(_print=args._print,seed=args.seed,dataset_path=args.dataset_path,save_path=data_save_path,num_clients=args.total_clients,balance=balance,partition=partition,dir_alpha=dir_alpha) + elif args.dataset == "cifar100": + if iid_flag: + pass + else: + cifar100_hetero_dir_part(_print=args._print,seed=args.seed,dataset_path=args.dataset_path,save_path=data_save_path,num_clients=args.total_clients,balance=balance,partition=partition,dir_alpha=dir_alpha) + else: + pass + + +def arrange_data_byclass(testdata,Class_nm): + + testdatasamples, testdatalabels = [],[] + for x, y in testdata: + testdatasamples.append(torch.unsqueeze(x, dim=0)) + testdatalabels.append(y) + + testdatasamples_all = torch.cat(testdatasamples, dim=0) + # print(testdatasamples_all.shape) + + indices_class = [[] for c in range(Class_nm)] + for i, c in enumerate(testdatalabels): + 
class TensorDataset(Dataset):
    """In-memory dataset wrapping image/label tensors of generated samples.

    Labels are shifted by ``label_offset`` on access so that synthetic
    (GAN-generated) samples can be told apart from real ones downstream:
    real classes occupy ``[0, num_classes)`` and shifted labels start at
    ``label_offset``.

    Fix/generalization: the offset was hard-coded to ``+10`` (CIFAR-10
    only), while the CIFAR-100 path elsewhere in this file shifts by 100.
    It is now a parameter whose default (10) preserves the old behavior.

    Args:
        images (torch.Tensor): ``n x c x h x w`` image batch.
        labels (torch.Tensor): ``n`` class indices.
        label_offset (int, optional): amount added to each label on
            ``__getitem__``. Defaults to 10.
    """

    def __init__(self, images, labels, label_offset=10):
        self.images = images.detach().float()
        self.labels = labels.detach()
        self.label_offset = label_offset

    def __getitem__(self, index):
        # Shift the label so callers can distinguish synthetic samples.
        return self.images[index], torch.tensor(self.labels[index] + self.label_offset)

    def __len__(self):
        return self.images.shape[0]

    def appendd(self, images, labels):
        """Append a batch of images/labels in place (original spelling kept).

        NOTE(review): appended labels are cast to float while the
        constructor keeps the labels' original dtype; preserved for
        compatibility, but the mixed dtypes look accidental — verify
        against downstream ``.item()`` users.
        """
        self.images = torch.cat([self.images, images.detach().float()], dim=0)
        self.labels = torch.cat([self.labels, labels.detach().float()], dim=0)
def read_gan_data(args, dir, num_classes):
    """Load GAN-generated images from ``dir`` as a label-shifted ImageFolder.

    Args:
        args: config namespace; ``model_outputdim`` selects the CIFAR-10 or
            CIFAR-100 folder wrapper, and ``dataset``/``dataset_mean``/
            ``dataset_std`` drive the transforms.
        dir (str): root directory of the generated images (one sub-folder
            per class, as ``ImageFolder`` requires).
        num_classes (int): unused; kept for call-site compatibility.

    Returns:
        A dataset whose targets are shifted past the real class range
        (+10 for CIFAR-10, +100 for CIFAR-100).

    Raises:
        ValueError: if ``args.model_outputdim`` is neither 10 nor 100.
            (The original fell through and crashed with
            ``UnboundLocalError`` on ``gan_dataset`` in that case.)
    """
    transforms = prepare_transforms(args.dataset, args.dataset_mean, args.dataset_std)
    if args.model_outputdim == 10:
        gan_dataset = CustomImageFolder_cifar10(dir, transform=transforms['tensor_dataset'])
    elif args.model_outputdim == 100:
        gan_dataset = CustomImageFolder_cifar100(dir, transform=transforms['tensor_dataset'])
    else:
        raise ValueError(
            "read_gan_data only supports model_outputdim 10 or 100, got "
            f"{args.model_outputdim}")
    return gan_dataset
def normalization(data):
    """Min-max scale ``data`` into the range [0, 1].

    Args:
        data (numpy.ndarray): array of any shape.

    Returns:
        numpy.ndarray: ``(data - min) / (max - min)``. An all-constant
        input is mapped to all zeros instead of dividing by zero (the
        original returned NaN/inf in that case).
    """
    lo = np.min(data)
    _range = np.max(data) - lo
    if _range == 0:
        # Constant input: no spread to normalize over.
        return np.zeros_like(data, dtype=float)
    return (data - lo) / _range
DataLoader(self.train_data, self.batch_size, shuffle=True, drop_last=False) + self.testloader = DataLoader(self.test_data, self.batch_size, shuffle=False, drop_last=False) + self.trainloaderfull = DataLoader(self.train_data, self.train_data_size, shuffle=False) + self.testloaderfull = DataLoader(self.test_data, self.test_data_size, shuffle=True) + self.label_counts = {} + self.init_loss_fn() + + if use_adam: + self.optimizer=torch.optim.Adam( + params=self.model.parameters(), + lr=self.learning_rate, betas=(0.9, 0.999), + eps=1e-08, weight_decay=1e-2, amsgrad=False) + else: + self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, momentum=0.9) + self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=self.lr_decay) + + + def init_loss_fn(self): + self.loss=nn.NLLLoss() + self.dist_loss = nn.MSELoss() + self.ensemble_loss=nn.KLDivLoss(reduction="batchmean") + self.ce_loss = nn.CrossEntropyLoss() + + + def get_parameters(self): + for param in self.model.parameters(): + param.detach() + return self.model.parameters() + + + def get_grads(self): + grads = [] + for param in self.model.parameters(): + if param.grad is None: + grads.append(torch.zeros_like(param.data)) + else: + grads.append(param.grad.data) + return grads + + + def test(self): + self.model.to(self.device) + self.model.eval() + test_acc = 0 + loss = 0 + for data in self.testloaderfull: + x, y = data[0].to(self.device), data[1].to(self.device) + output = self.model(x) + loss += self.ce_loss(output, y).detach().cpu().item() + test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item() + self.model.to("cpu") + return test_acc, loss, y.shape[0] + + + def save_model(self,round): + model_path = self.dataset + "_" + self.algorithm + model_path += "_" + str(self.learning_rate) + "lr" + "_" + str(self.num_clients_per_round) \ + + "ncpr" + "_" + str(self.batch_size) + "bs" + "_" + str(self.local_epochs) + "le" + "_" + 
str(self.seed) + "s" + model_path = os.path.join(self.save_path, "models", model_path) + if not os.path.exists(model_path): + os.makedirs(model_path) + torch.save(self.model, os.path.join(model_path, "user_" + self.id + ".pth")) + + def load_model(self): + model_path = self.dataset + "_" + self.algorithm + model_path += "_" + str(self.learning_rate) + "lr" + "_" + str(self.num_clients_per_round) \ + + "ncpr" + "_" + str(self.batch_size) + "bs" + "_" + str(self.local_epochs) + "le" + "_" + str(self.seed) + "s" + model_path = os.path.join(self.save_path, "models", model_path) + assert (os.path.exists(model_path)) + self.model = torch.load(os.path.join(model_path, "server" + ".pth")) + diff --git a/The_Collaborative_Training_Phase/flalgorithms/fedslaug.py b/The_Collaborative_Training_Phase/flalgorithms/fedslaug.py new file mode 100644 index 0000000..79fda8b --- /dev/null +++ b/The_Collaborative_Training_Phase/flalgorithms/fedslaug.py @@ -0,0 +1,271 @@ +from .clientbase import Client +from .serverbase import Server +from data.data_utils import get_global_test_data,data_division,read_client_data,read_gan_data +import numpy as np +import time +from torch.utils.data import DataLoader +from tqdm import tqdm +import torch +import copy +import os +from torchvision.utils import save_image +from torch.utils.data import ConcatDataset +from models.acgan import Model + + +class FedSlaug_Server(Server): + def __init__(self, args, model, seed): + + data_division(args,balance=None,partition=args.partition_method,dir_alpha=args.dirichlet_alpha) + test_data_global,test_data_by_class = get_global_test_data(args) + self.test_data_by_class = test_data_by_class + super().__init__(args, model, test_data_global, seed) + + for id in range(args.total_clients): + train_data, test_data = read_client_data(args,id=id) + val_data = read_client_data(args,id=id, transform=False) + client = FedSlaug_Client(args, id, model, train_data, test_data, val_data,use_adam=False) + self.clients.append(client) 
    def update_feature(self, server_class_feature, server_class_num, args):
        """Aggregate per-client class features into a sample-count-weighted
        global mean per class, stored in ``self.class_features``.

        Args:
            server_class_feature (list): one entry per selected client; each
                entry is a list of per-class feature tensors.
            server_class_num (list): one entry per selected client; each
                entry is a list of per-class sample counts.
            args: config namespace; ``model_outputdim`` is the class count.

        NOTE: both input lists are mutated in place — slot 0 of each is
        overwritten with the running weighted sum / total count.
        NOTE(review): if the total count for some class across all selected
        clients is 0, the final division is by zero (NaN/inf features) —
        confirm callers guarantee every class is represented.
        """
        # Seed the accumulator (client 0): feature * its sample count.
        for label in range(args.model_outputdim):
            server_class_feature[0][label] = server_class_feature[0][label]*server_class_num[0][label]
        # Fold in the remaining clients' weighted features and counts.
        for cli in range(1,len(server_class_num)):
            for label in range(args.model_outputdim):
                server_class_feature[0][label] += server_class_feature[cli][label]*server_class_num[cli][label]
                server_class_num[0][label] += server_class_num[cli][label]
        # Divide by total counts to obtain the weighted mean per class.
        for label in range(args.model_outputdim):
            server_class_feature[0][label] = server_class_feature[0][label]/server_class_num[0][label]
        # Stack into a (num_classes, feature_dim) tensor for loss2 on clients.
        self.class_features = torch.stack(server_class_feature[0])
self.selected_clients,client_idxs = self.select_clients(glob_round,self.num_clients_per_round,return_idx=True) + self._print("Clients selected in this round:{}".format(client_idxs)) + + if glob_round % self.eval_every == 0: + self._print("Start test at Glob_round_{}".format(glob_round)) + self.test_global_model(args,glob_round=glob_round,test_each_classacc=True) + self._print("Glob_round_{} test done".format(glob_round)) + + # Train selected local clients + self.timestamp = time.time() # log client-training start time + w_local = {} + self._print("client model training start.") + server_class_feature = [] + server_class_num = [] + for client in tqdm(self.selected_clients): # allow selected clients to train + client.model.load_state_dict(self.latest_model_params) + w,features,feature_num = client.train(glob_round,self.acgan_model,self.class_features) + server_class_feature.append(features) + server_class_num.append(feature_num) + w_local[client.id] = w + self._print("client model training done.") + # Record selected clients training time + curr_timestamp = time.time() + train_time = (curr_timestamp - self.timestamp) / len(self.selected_clients) + self.metrics['client_train_time'].append(train_time) + + # Update models + self.timestamp = time.time() # log server-agg start time + self.latest_model_params = self.aggregate_parameters(w_local) # only aggregate the selected clients + self.update_feature(server_class_feature,server_class_num,args) + + curr_timestamp=time.time() # log server-agg end time + agg_time = curr_timestamp - self.timestamp + self.metrics['server_agg_time'].append(agg_time) + + # Save final results + self.save_results(args) + # self.save_model(glob_round) + max_acc = max(self.metrics['glob_acc']) + max_acc_index = self.metrics['glob_acc'].index(max_acc) + mac_acc_round = self.metrics['glob_round'][max_acc_index] + self._print("Max glob_acc in this time is {} at round {}.".format(max_acc,mac_acc_round)) + + def test_global_model(self, args, 
    def global_test(self, dataset):
        """Evaluate ``self.model`` over ``dataset`` (an iterable of
        ``(x, y)`` batches already prepared as a DataLoader).

        Returns:
            tuple: ``(correct, test_loss)`` where ``correct`` is the number
            of correctly classified samples and ``test_loss`` is the summed
            cross-entropy divided by ``self.global_test_size``.

        NOTE(review): the loss is always normalized by the *global* test
        size, even when the caller passes a per-class subset (see
        ``test_global_model``) — presumably intentional (per-global-sample
        average), but verify.
        """
        correct = 0
        test_loss = 0
        for data in dataset:
            x, y = data[0].to(self.device), data[1].to(self.device)
            output = self.model(x)
            test_loss += self.ce_loss_sum(output, y).detach().cpu().item() # sum up batch loss
            correct += (torch.sum( torch.argmax(output, dim=1) == y)).item()
        test_loss /= self.global_test_size
        return correct, test_loss
os.path.isdir(class_save_path): + os.makedirs(class_save_path,exist_ok=True) + self.class_save_path.append(class_save_path) + + self.combined_dataset = None + self.gan_dataset = None + self.combined_dataloader = None + def train(self, glob_round, acgan_model,class_features=None,lr_decay=True): + + self.model.to(self.device) + self.model.train() + #generate fake images: + if self.combined_dataset is None: + gen_num = self.train_data_size//self.num_classes + + for target in range(self.num_classes): + target_labels = [target for i in range(gen_num)] + if self.num_classes == 10: + gen_imgs = acgan_model.generate_data(target_labels) + elif self.num_classes == 100: + if self.args.GAN_type == 0: + gen_imgs = acgan_model.generate_data(target_labels) + else: + gan_num = target//10 + gen_imgs = acgan_model[gan_num].generate_data(target_labels) + for i in range(gen_num): + save_image(gen_imgs.data[i],self.class_save_path[target]+'/'+str(i)+'.png',normalize=True) + self.gan_dataset = read_gan_data(self.args,self.save_path,self.num_classes) + self.combined_dataset = ConcatDataset([self.train_data, self.gan_dataset]) + self.combined_dataloader = DataLoader(self.combined_dataset,self.batch_size,shuffle=True) + + my_class_feature = [[] for i in range(self.num_classes)] + class_num = [0 for i in range(self.num_classes)] + for epoch in range(1, self.local_epochs + 1): + for data in self.combined_dataloader: + x, y = data[0].to(self.device), data[1].to(self.device) + self.optimizer.zero_grad() + output,feature,_,_,_,_=self.model(x,out_feature=True) + output1=[] + output2=[] + y1=[] + y2=[] + for data,label in zip(output,y): + if label< self.num_classes: + output1.append(data) + y1.append(label) + elif label >= self.num_classes: + output2.append(data) + y2.append(label) + if output1 != [] and output2 !=[]: + output1=torch.stack(output1) + output2=torch.stack(output2) + y1=torch.stack(y1) + y2=torch.stack(y2) + loss1 = self.ce_loss(output1, y1)+0.5*self.ce_loss(output2, y2) #loss1 + 
elif output1==[] and output2 !=[]: + output2=torch.stack(output2) + y2=torch.stack(y2) + loss1 = 0.5*self.ce_loss(output2, y2) #loss1 + elif output1!=[] and output2 == []: + output1=torch.stack(output1) + y1=torch.stack(y1) + loss1 = self.ce_loss(output1, y1) #loss1 + if class_features is not None: + new_target = torch.where(y >= self.num_classes, y - self.num_classes, y).to(self.device) + feature_ = [[] for i in range(self.args.model_outputdim)] + feature_now = [] + mylabels = [] + for ss in range(len(y)): + label_tmp = new_target[ss] + feature_[label_tmp].append(feature[ss]) + for label_tmp in range(self.args.model_outputdim): + if feature_[label_tmp] != []: + mylabels.append(label_tmp) + feature_now.append((torch.mean(torch.stack(feature_[label_tmp]),dim=0))) + feature_now = torch.stack(feature_now) + mylabels = torch.tensor(mylabels).to(self.device) + + cos_sim = torch.cosine_similarity(feature_now.unsqueeze(1), class_features.unsqueeze(0), dim=-1) + cos_sim = cos_sim/0.5 #tempreture=0.5 + loss2 = 1.0 * self.ce_loss(cos_sim, mylabels) + else: + loss2 = 0 + + loss = loss1 + loss2 + loss.backward() + self.optimizer.step() + + gan_loader = DataLoader(self.gan_dataset,self.batch_size,shuffle=True) + self.model.eval() + for data in gan_loader: + x, y = data[0].to(self.device), data[1].to(self.device) + output,feature,_,_,_,_=self.model(x,out_feature=True) + for i in range(y.shape[0]): + label = y[i].item()-self.num_classes + if 0 <= label < self.num_classes: + my_class_feature[label].append(feature[i].clone().detach()) + + for i in range(self.num_classes): + class_num[i] = len(my_class_feature[i]) + my_class_feature[i] = torch.stack(my_class_feature[i]) + class_mean = my_class_feature[i].mean(dim=0) + my_class_feature[i] = class_mean + + self.model.to("cpu") + return copy.deepcopy(self.model.state_dict()),my_class_feature ,class_num + diff --git a/The_Collaborative_Training_Phase/flalgorithms/serverbase.py 
b/The_Collaborative_Training_Phase/flalgorithms/serverbase.py new file mode 100644 index 0000000..b6958f4 --- /dev/null +++ b/The_Collaborative_Training_Phase/flalgorithms/serverbase.py @@ -0,0 +1,149 @@ +import torch +import os +import numpy as np +import h5py +import copy +import torch.nn.functional as F +import time +import torch.nn as nn +from torch.utils.data import DataLoader +from data.data_utils import METRICS +from tqdm import tqdm +from collections import OrderedDict + +class Server: + def __init__(self, args, model, global_test_data, seed): + + # Set up the main attributes + self.device = args.device + self.eval_every = args.eval_every + self.save_every = args.save_every + self.dataset = args.dataset + self.batch_size = args.batch_size + self.global_test_size = len(global_test_data) + self.global_test_data = global_test_data + self.global_testloader = DataLoader(self.global_test_data, self.batch_size, shuffle=False, drop_last=False) + self.num_glob_rounds = args.num_glob_rounds + self.local_epochs = args.local_epochs + self.learning_rate = args.learning_rate + self.total_train_samples = 0 + self.model = model + self.model_name = self.model.model_name + self.latest_model_params = self.model.state_dict() + self.clients = [] + self.selected_clients = [] + self.num_clients_per_round = args.num_clients_per_round + self.algorithm = args.algorithm + self.seed = seed + self.metrics = {key:[] for key in METRICS} + self.timestamp = None + self.save_path = args.save_path + self._print = args._print + self.init_loss_fn() + + def send_parameters(self, beta=1, selected=False): + clients = self.clients + if selected: + assert (self.selected_clients is not None and len(self.selected_clients) > 0) + clients = self.selected_clients + for client in clients: + client.model.load_state_dict(self.latest_model_params) + + def aggregate_parameters(self, w_local_map): + assert (self.selected_clients is not None and len(self.selected_clients) > 0) + model_param_container = 
OrderedDict() + for key, val in self.latest_model_params.items(): + model_param_container[key] = torch.zeros_like(val) + total_train = 0 + for client in self.selected_clients: + total_train += client.train_data_size + for key, val in w_local_map[client.id].items(): + model_param_container[key] += val*client.train_data_size + for key, val in model_param_container.items(): + model_param_container[key] = val/total_train + + return model_param_container + + + def check_param(self): + server_model_dict = self.model.state_dict() + for client in self.clients: + client_model_dict = client.model.state_dict() + for name in server_model_dict: + if not torch.equal(server_model_dict[name].data,client_model_dict[name].data): + print(client.id) + + + def save_results(self, args): + alg = self.dataset + "_" + self.algorithm + alg += "_" + str(self.learning_rate) + "lr" + "_" + str(self.num_clients_per_round) \ + + "ncpr" + "_" + str(self.batch_size) + "bs" + "_" + str(self.local_epochs) + "le" + "_" + str(self.seed) + "s" + records_path = os.path.join(self.save_path, "records") + if not os.path.exists(records_path): + os.makedirs(records_path) + with h5py.File("{}/{}.h5".format(records_path, alg), 'w') as hf: + for key in self.metrics: + hf.create_dataset(key, data=self.metrics[key]) + hf.close() + + + def save_model(self,round): + model_path = self.dataset + "_" + self.algorithm + model_path += "_" + str(self.learning_rate) + "lr" + "_" + str(self.num_clients_per_round) \ + + "ncpr" + "_" + str(self.batch_size) + "bs" + "_" + str(self.local_epochs) + "le" + "_" + str(self.seed) + "s" + model_path = os.path.join(self.save_path, "models", model_path) + if not os.path.exists(model_path): + os.makedirs(model_path) + + if round == self.num_glob_rounds: + model_save_name = os.path.join(model_path, "server" + ".pth") + else: + model_save_name = os.path.join(model_path, "server_{}".format(round) + ".pth") + self.model.save(model_save_name) + + + def load_model(self): + model_path = 
self.dataset + "_" + self.algorithm + model_path += "_" + str(self.learning_rate) + "lr" + "_" + str(self.num_clients_per_round) \ + + "ncpr" + "_" + str(self.batch_size) + "bs" + "_" + str(self.local_epochs) + "le" + "_" + str(self.seed) + "s" + model_path = os.path.join(self.save_path, "models", model_path) + assert (os.path.exists(model_path)) + self.model = self.model.load(model_path) + + + def select_clients(self, round, num_clients_per_round, return_idx=False): + if(num_clients_per_round == len(self.clients)): + self._print("All clients are selected") + return self.clients + + num_clients_per_round = min(num_clients_per_round, len(self.clients)) + if return_idx: + client_idxs = np.random.choice(range(len(self.clients)), num_clients_per_round, replace=False) + return [self.clients[i] for i in client_idxs], client_idxs + else: + return np.random.choice(self.clients, num_clients_per_round, replace=False) + + + def init_loss_fn(self): + self.loss=nn.NLLLoss() + self.ensemble_loss=nn.KLDivLoss(reduction="batchmean") + self.ce_loss_sum = nn.CrossEntropyLoss(reduction='sum') + self.ce_loss_mean = nn.CrossEntropyLoss() + self.ce_loss = nn.CrossEntropyLoss(reduction='none') + self.mse_loss_sum = torch.nn.MSELoss(reduction='sum') + self.mse_loss = torch.nn.MSELoss(reduction='none') + + + def test(self, selected=False): + num_samples = [] + tot_correct = [] + losses = [] + clients = self.selected_clients if selected else self.clients + for c in tqdm(clients): + ct, c_loss, ns = c.test() + tot_correct.append(ct*1.0) + num_samples.append(ns) + losses.append(c_loss) + ids = [c.id for c in self.clients] + + return ids, num_samples, tot_correct, losses \ No newline at end of file diff --git a/The_Collaborative_Training_Phase/main.py b/The_Collaborative_Training_Phase/main.py new file mode 100644 index 0000000..bb1fddd --- /dev/null +++ b/The_Collaborative_Training_Phase/main.py @@ -0,0 +1,57 @@ +from config import flargs +from config import print_parameters +import os +from 
# --- main.py entry points (reconstructed from the diff paste) ---

def main():
    """Entry point: configure device and run directory, then dispatch to the
    function named by --function (e.g. run_job)."""
    flargs.device = torch.device('cuda:{}'.format(flargs.gpu_num) if torch.cuda.is_available() else 'cpu')
    func = flargs.function
    flargs.save_path = os.path.join(flargs.save_path, flargs.algorithm, datetime.now().strftime('%Y%m%d_%H%M%S'))
    os.makedirs(flargs.save_path)
    log = init_log(flargs.save_path)
    flargs._print = log.info
    print_parameters(flargs._print, flargs)
    try:
        # FIX: dispatch via an explicit lookup instead of eval() -- eval
        # executed arbitrary text taken from the command line.
        target = globals().get(func)
        if not callable(target):
            raise ValueError("Unknown --function: {!r}".format(func))
        target()
    except Exception as e:
        # Remove the half-initialized run directory, then re-raise.
        shutil.rmtree(flargs.save_path)
        print(e)
        raise


def create_model(model_name, model_outputdim, _print, algorithm):
    """Build the classification model.

    The output dimension is doubled: logits [0, outputdim) are the real
    classes, [outputdim, 2*outputdim) the GAN-sample classes.
    """
    target_model = getattr(models, model_name)(num_classes=model_outputdim * 2)
    return target_model


def create_server_and_clients(args, i):
    """Create the model and the FedSlaug server (which builds its clients)."""
    model = create_model(args.model_name, args.model_outputdim, args._print, args.algorithm)
    server = FedSlaug(args, model, i)
    return server


def run_job(args=flargs):
    """Repeat the full federated training `repeat_times` times with
    incrementing seeds."""
    start_time = datetime.now()
    for i in range(args.repeat_times):
        current_seed = args.seed + i
        torch.manual_seed(current_seed)
        torch.cuda.manual_seed(current_seed)
        np.random.seed(current_seed)
        flargs._print("--------------Start_training_iteration_{}--------------".format(i))
        server = create_server_and_clients(args, current_seed)
        if args.train:
            server.train(args)
    end_time = datetime.now()
    flargs._print("total time used: " + str((end_time - start_time)))


if __name__ == '__main__':
    main()
# --- models/BasicModule.py (reconstructed from the diff paste) ---
#coding:utf8
import time

import torch as t


class BasicModule(t.nn.Module):
    """Thin nn.Module wrapper adding a model name and save/load helpers."""

    def __init__(self):
        super(BasicModule, self).__init__()
        self.model_name = str(type(self))  # default name (translated from: 默认名字)

    def load(self, path):
        """Restore parameters in place from the checkpoint at ``path``."""
        self.load_state_dict(t.load(path))

    def save(self, name=None):
        """Save the state dict; returns the file path used.

        FIX: the old default name used '%H:%M:%S' (colons are invalid in
        Windows filenames) and passed model_name through strftime, which
        breaks if the name ever contains a '%'.
        """
        if name is None:
            name = './' + self.model_name + '_' + time.strftime('%m%d_%H%M%S') + '.pth'
        t.save(self.state_dict(), name)
        return name


class Flat(t.nn.Module):
    """Flatten every dimension after the batch dimension."""

    def __init__(self):
        super(Flat, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


def print_network(model, _print):
    """Log ``model`` and its total parameter count via ``_print``."""
    num_params = sum(p.numel() for p in model.parameters())
    _print("The number of parameters of this model: {}".format(num_params))
    _print(model)
nn.ConvTranspose2d(384, 192, 4, 1, 0, bias=False), + nn.BatchNorm2d(192), + nn.ReLU(True), + ) + # Transposed Convolution 3 + self.tconv3 = nn.Sequential( + nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False), + nn.BatchNorm2d(96), + nn.ReLU(True), + ) + # Transposed Convolution 4 + self.tconv4 = nn.Sequential( + nn.ConvTranspose2d(96, 48, 4, 2, 1, bias=False), + nn.BatchNorm2d(48), + nn.ReLU(True), + ) + # Transposed Convolution 4 + self.tconv5 = nn.Sequential( + nn.ConvTranspose2d(48, 3, 4, 2, 1, bias=False), + nn.Tanh(), + ) + + def forward(self, noise, labels): + gen_input = torch.mul(self.label_emb(labels), noise) + fc1 = self.fc1(gen_input) + fc1 = fc1.view(-1, 384, 1, 1) + tconv2 = self.tconv2(fc1) + tconv3 = self.tconv3(tconv2) + tconv4 = self.tconv4(tconv3) + tconv5 = self.tconv5(tconv4) + return tconv5 + + def get_params(self): + return self.state_dict() + + def set_params(self, model_params): + self.load_state_dict(model_params) + + +class Discriminator_strong(nn.Module): + def __init__(self, n_classes, image_size, channels=1): + super(Discriminator_strong, self).__init__() + + # Convolution 1 + self.conv1 = nn.Sequential( + nn.Conv2d(3, 16, 3, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.5, inplace=False), + ) + # Convolution 2 + self.conv2 = nn.Sequential( + nn.Conv2d(16, 32, 3, 1, 1, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.5, inplace=False), + ) + # Convolution 3 + self.conv3 = nn.Sequential( + nn.Conv2d(32, 64, 3, 2, 1, bias=False), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.5, inplace=False), + ) + # Convolution 4 + self.conv4 = nn.Sequential( + nn.Conv2d(64, 128, 3, 1, 1, bias=False), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.5, inplace=False), + ) + # Convolution 5 + self.conv5 = nn.Sequential( + nn.Conv2d(128, 256, 3, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.5, 
inplace=False), + ) + # Convolution 6 + self.conv6 = nn.Sequential( + nn.Conv2d(256, 512, 3, 1, 1, bias=False), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.5, inplace=False), + ) + # discriminator fc + self.fc_dis = nn.Linear(4 * 4 * 512, 1) + # aux-classifier fc + self.fc_aux = nn.Linear(4 * 4 * 512, n_classes) + # softmax and sigmoid + self.softmax = nn.Softmax() + self.sigmoid = nn.Sigmoid() + + def forward(self, img): + conv1 = self.conv1(img) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + conv4 = self.conv4(conv3) + conv5 = self.conv5(conv4) + conv6 = self.conv6(conv5) + flat6 = conv6.view(-1, 4 * 4 * 512) + fc_dis = self.fc_dis(flat6) + fc_aux = self.fc_aux(flat6) + classes = self.softmax(fc_aux) + realfake = self.sigmoid(fc_dis).view(-1, 1).squeeze(1) + + return realfake, classes + + def get_params(self): + return self.state_dict() + + def set_params(self, model_params): + self.load_state_dict(model_params) + + +class Generator_weak(BasicModule): + def __init__(self, device, n_classes, image_size, channels=3, noise_dim=100): + super(Generator_weak, self).__init__() + self.model_name = 'Generator' + self.device = device + self.n_classes = n_classes + self.noise_dim = noise_dim + self.init_size = image_size // 4 + self.input_emb = nn.Sequential(nn.Linear(self.n_classes, 4096)) + self.noise_emb = nn.Sequential(nn.Linear(self.noise_dim, 4096)) + + self.conv_blocks = nn.Sequential( + nn.BatchNorm2d(128), + nn.Conv2d(128, 128, 3, stride=1, padding=1), + nn.BatchNorm2d(128, 0.8), + nn.LeakyReLU(0.2, inplace=True), + nn.Upsample(scale_factor=2), + nn.Conv2d(128, 64, 3, stride=1, padding=1), + nn.BatchNorm2d(64, 0.8), + nn.LeakyReLU(0.2, inplace=True), + nn.Upsample(scale_factor=2), + nn.Conv2d(64, channels, 3, stride=1, padding=1), + nn.Tanh(), + ) + + def forward(self, noise,labels): + one_hot_labels = self.get_one_hot(labels,self.n_classes) + embedded_input = self.input_emb(one_hot_labels) + embedded_noise = 
self.noise_emb(noise) + z = torch.cat((embedded_noise, embedded_input), dim=1) + z = z.view(z.shape[0], 128, self.init_size, self.init_size) + img = self.conv_blocks(z) + return img + + def get_one_hot(self, target, num_class): + one_hot=torch.zeros(target.shape[0],num_class).cuda(self.device) + one_hot=one_hot.scatter(dim=1,index=target.long().view(-1,1),value=1.) + return one_hot + + def get_params(self): + return self.state_dict() + + def set_params(self, model_params): + self.load_state_dict(model_params) + +class Discriminator_weak(nn.Module): + def __init__(self, n_classes, image_size, channels=1): + super(Discriminator_weak, self).__init__() + + def discriminator_block(in_filters, out_filters, bn=True): + """Returns layers of each discriminator block""" + block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] + if bn: + block.append(nn.BatchNorm2d(out_filters, 0.8)) + return block + + self.conv_blocks = nn.Sequential( + *discriminator_block(channels, 16, bn=False), + *discriminator_block(16, 32), + *discriminator_block(32, 64), + *discriminator_block(64, 128), + ) + + # The height and width of downsampled image + ds_size = math.ceil(image_size / 2 ** 4) + + # Output layers + self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) + self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, n_classes), nn.Softmax()) + + def forward(self, img): + out = self.conv_blocks(img) + out = out.view(out.shape[0], -1) + validity = self.adv_layer(out) + label = self.aux_layer(out) + + return validity, label + + def get_params(self): + return self.state_dict() + + def set_params(self, model_params): + self.load_state_dict(model_params) + + +default_opt = { + "lr": 0.0002, # adam: learning rate + "b1": 0.5, # adam: decay of first order momentum of gradient + "b2": 0.999, # adam: decay of first order momentum of gradient +} + + +class Model(BaseModel): + """ + AC-GAN model + :param dataset_info: 
tuple of dataset info (num_classes, image_size, channels) + :param optimizer: a model optimizer (This optimizer is ignored) + :param device: model device + """ + + def __init__(self, num_classes, image_size, channels, optimizer,args, device=None): + super(Model, self).__init__(num_classes, None, device) + self.size = 0 + self.num_classes = num_classes + self.image_size = image_size + self.channels = channels + if args.model_outputdim == 10 and args.GAN_type==0: + self.latent_dim = 100 + elif args.model_outputdim == 10 and args.GAN_type==1: + self.latent_dim = 110 + elif args.model_outputdim == 100 and args.GAN_type==0: + self.latent_dim = 256 + + self.device = 'cpu' if not device else device + + # Loss functions + self.adversarial_loss = torch.nn.BCELoss() + self.auxiliary_loss = torch.nn.CrossEntropyLoss() + # Initialize generator and discriminator + if args.GAN_type==0: + self.generator = Generator_weak(self.device,num_classes, image_size, channels, self.latent_dim) + self.discriminator = Discriminator_weak(num_classes, image_size, channels) + elif args.GAN_type==1: + self.generator = Generator_strong(num_classes, image_size, channels, self.latent_dim) + self.discriminator = Discriminator_strong(num_classes, image_size, channels) + + self._init_models() + + # Optimizers + self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=default_opt["lr"], + betas=(default_opt["b1"], default_opt["b2"])) + self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=default_opt["lr"], + betas=(default_opt["b1"], default_opt["b2"])) + + self.FloatTensor = torch.cuda.FloatTensor if self.device else torch.FloatTensor + self.LongTensor = torch.cuda.LongTensor if self.device else torch.LongTensor + + @staticmethod + def _weights_init_normal(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + torch.nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) + 
torch.nn.init.constant_(m.bias.data, 0.0) + + def _init_models(self): + """ + Initialize model weights + :return: + """ + self.generator.to(self.device) + self.discriminator.to(self.device) + self.adversarial_loss.to(self.device) + self.auxiliary_loss.to(self.device) + + # Initialize weights + self.generator.apply(self._weights_init_normal) + self.discriminator.apply(self._weights_init_normal) + + def _sample_image(self, n_row, image_name, image_save_dir="cifar10-image-0.2-12000-2"): + """ + generating image data for debug. generating images and saving image file + :param n_row: indicate the number of generate image data (will generate n_row*n_row images ) + :param image_name: output image name + :param image_save_dir: output image save dir + :return: + """ + """Saves a grid of generated digits ranging from 0 to n_classes""" + if not os.path.isdir(image_save_dir): + os.makedirs(image_save_dir, exist_ok=True) + + # Sample noise + z = Variable(self.FloatTensor(np.random.normal(0, 1, (n_row * 10, self.latent_dim)))) + # Get labels ranging from 0 to n_classes for n rows + labels = np.array([num for _ in range(10) for num in range(n_row)]) + labels = Variable(self.LongTensor(labels)) + gen_imgs = self.generator(z, labels) + save_image(gen_imgs.data, os.path.join(image_save_dir, "%d.png" % image_name), nrow=n_row, normalize=True) + + def create_model(self): + pass + + def forward(self, x): + raise NotImplemented("ACGAN not support forward inference yet.") + + def get_params(self): + return self.generator.get_params(), self.discriminator.get_params() + + def set_params(self, model_params): + generator_params, discriminator_params = model_params + self.generator.set_params(generator_params) + self.discriminator.set_params(discriminator_params) + + def get_gradients(self, data, model_len): + raise NotImplemented("GAN mode get gradients method is not implemented yet") + + def solve_inner(self, data, num_epochs=1, batch_size=64, sample_interval=-1, 
verbose=True,model_name=None,gan_model_save_dir=None,class_models=None, anchor=None,): + """ + Solves local optimization problem + :param data: + :param num_epochs: + :param batch_size: + :param sample_interval: + :param verbose: + :return: (soln, comp) + - soln: local optimization solution + - comp: number of FLOPs executed in training process + """ + self.generator.train() + self.discriminator.train() + self.generator.to(self.device) + self.discriminator.to(self.device) + if self.args.GAN_type==1: + ranger = range(num_epochs) + data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,pin_memory=True) + for epoch in ranger: + for i, (imgs, labels) in enumerate(data_loader): + batch_size = imgs.shape[0] + + # Adversarial ground truths + valid = Variable(self.FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) + fake = Variable(self.FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) + + # Configure input + real_imgs = Variable(imgs.type(self.FloatTensor)) + labels = Variable(labels.type(self.LongTensor)) + + # ----------------- + # Train Generator + # ----------------- + self.optimizer_G.zero_grad() + + # Sample noise and labels as generator input + z = Variable(self.FloatTensor(np.random.normal(0, 1, (batch_size, self.latent_dim)))) + gen_labels = Variable(self.LongTensor(np.random.randint(0, self.num_classes, batch_size))) + + # Generate a batch of images + gen_imgs = self.generator(z, gen_labels) + + # Loss measures generator's ability to fool the discriminator + validity, pred_label = self.discriminator(gen_imgs) + valid=valid.squeeze(-1) + g_loss = 0.5 * (self.adversarial_loss(validity, valid) + self.auxiliary_loss(pred_label, gen_labels)) + + g_loss.backward() + self.optimizer_G.step() + + # --------------------- + # Train Discriminator + # --------------------- + self.optimizer_D.zero_grad() + + # Loss for real images + real_pred, real_aux = self.discriminator(real_imgs) + d_real_loss = (self.adversarial_loss(real_pred, valid) + 
self.auxiliary_loss(real_aux, labels)) / 2 + + # Loss for fake images + fake_pred, fake_aux = self.discriminator(gen_imgs.detach()) + fake = fake.squeeze(-1) + d_fake_loss = (self.adversarial_loss(fake_pred, fake) + self.auxiliary_loss(fake_aux, gen_labels)) / 2 + + # Total discriminator loss + d_loss = (d_real_loss + d_fake_loss) / 2 + + # Calculate discriminator accuracy + pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0) + gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0) + d_acc = np.mean(np.argmax(pred, axis=1) == gt) + + d_loss.backward() + self.optimizer_D.step() + + if verbose: + print( + "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" + % (epoch, num_epochs, i, len(data_loader), d_loss.item(), 100 * d_acc, g_loss.item()) + ) + + batches_done = epoch * len(data_loader) + i + if sample_interval > 0 and batches_done % sample_interval == 0: + self._sample_image(n_row=len(data.classes), image_name=batches_done) + + elif self.args.GAN_type==0: + ranger = range(num_epochs) + data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,pin_memory=True) + for epoch in ranger: + for i, (imgs, labels) in enumerate(data_loader): + batch_size = imgs.shape[0] + + # Adversarial ground truths + valid = Variable(self.FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) + fake = Variable(self.FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) + + + # ----------------- + # Train Generator + # ----------------- + self.optimizer_G.zero_grad() + z = Variable(self.FloatTensor(np.random.normal(0, 1, (batch_size, self.latent_dim)))) + gen_labels = Variable(self.LongTensor(np.random.randint(0, self.num_classes, batch_size))) + + gen_imgs = self.generator(z, gen_labels) + + #get feature + feature = [[] for i in range(self.num_classes)] + for imgi in range(len(gen_imgs)): + label = gen_labels[imgi] + output,feature_i,_,_,_,_ = 
class_models[label](gen_imgs[imgi].unsqueeze(0),out_feature = True) + feature[label].append(feature_i) + class_mem = torch.zeros(self.num_classes).to(self.device) + feature_now =[] + for label in range(self.num_classes): + if feature[label] != []: + class_mem[label] = True + tmp = (torch.mean(torch.cat(feature[label],dim=0),dim=0)).unsqueeze(0) + feature_now.append(tmp) + feature_now = torch.cat(feature_now,dim=0) + feature_anchor = anchor[class_mem] + + # ##10-class data + # loss1 = self.weak_loss(feature_now,feature_anchor) + # z1, z2 = torch.split(z, z.size(0)//2, dim=0) + # gen_img1, gen_img2 = torch.split(gen_imgs, z1.size(0), dim=0) + # lz = torch.mean(torch.abs(gen_img1 - gen_img2)) / torch.mean(torch.abs(z1 - z2)) + # eps = 1 * 1e-5 + # loss2 = 5000/(lz + eps) + # g_loss = (loss1 + loss2) + + ##100-class data + loss1 = self.weak_loss(feature_now,feature_anchor)*0.01 + z1, z2 = torch.split(z, z.size(0)//2, dim=0) + gen_img1, gen_img2 = torch.split(gen_imgs, z1.size(0), dim=0) + lz = torch.mean(torch.abs(gen_img1 - gen_img2)) / torch.mean(torch.abs(z1 - z2)) + eps = 1 * 1e-5 + loss2 = 100/(lz + eps) + g_loss = 0.5 * (loss1 + loss2) + + g_loss.backward() + self.optimizer_G.step() + + if verbose: + print( + "[Epoch %d/%d] [Batch %d/%d] [G loss1: %f, div_loss: %f]" + % (epoch, num_epochs, i, len(data_loader), loss1.item(), loss2.item()) + ) + + batches_done = epoch * len(data_loader) + i + if sample_interval > 0 and (batches_done+1) % 156 == 0: + self._sample_image(n_row=len(data.classes), image_name=batches_done) + + solution = self.get_params() + + comp = 0 # compute cost + return solution, comp + + def solve_iters(self, data, num_iters=1, batch_size=64): + """ + Solves local optimization problem + :param data: + :param num_iters: + :param batch_size: + :return: + """ + raise NotImplemented("GAN mode solve iter method is not implemented yet") + + def test(self, test_sets): + self.generator.eval() + self.discriminator.eval() + self.generator.to(self.device) + 
self.discriminator.to(self.device) + + test_g_loss = [] + test_d_loss = [] + d_correct = 0 + batch_size = 1000 + with torch.no_grad(): + data_loader = DataLoader(test_sets, batch_size=batch_size, shuffle=True) + for data, target in data_loader: + data, target = data.to(self.device), target.to(self.device) + batch_size = data.shape[0] + + # Adversarial ground truths + valid = Variable(self.FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) + fake = Variable(self.FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) + + # Configure input + real_imgs = Variable(data.type(self.FloatTensor)) + labels = Variable(target.type(self.LongTensor)) + + # ----------------- + # Generator + # ----------------- + self.optimizer_G.zero_grad() + + # Sample noise and labels as generator input + z = Variable(self.FloatTensor(np.random.normal(0, 1, (batch_size, self.latent_dim)))) + gen_labels = Variable(self.LongTensor(np.random.randint(0, self.num_classes, batch_size))) + + # Generate a batch of images + gen_imgs = self.generator(z, gen_labels) + + # Loss measures generator's ability to fool the discriminator + validity, pred_label = self.discriminator(gen_imgs) + g_loss = 0.5 * (self.adversarial_loss(validity, valid) + self.auxiliary_loss(pred_label, gen_labels)) + test_g_loss.append(g_loss.item()) + + # --------------------- + # Discriminator + # --------------------- + # Loss for real images + real_pred, real_aux = self.discriminator(real_imgs) + d_real_loss = (self.adversarial_loss(real_pred, valid) + self.auxiliary_loss(real_aux, labels)) / 2 + + # Loss for fake images + fake_pred, fake_aux = self.discriminator(gen_imgs.detach()) + d_fake_loss = (self.adversarial_loss(fake_pred, fake) + self.auxiliary_loss(fake_aux, gen_labels)) / 2 + + # Total discriminator loss + d_loss = (d_real_loss + d_fake_loss) / 2 + test_d_loss.append(d_loss.item()) + + # Calculate discriminator accuracy + pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], 
axis=0) + gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0) + d_acc = np.sum(np.argmax(pred, axis=1) == gt) + d_correct += d_acc + + test_g_loss = np.mean(np.array(test_g_loss)) + test_d_loss = np.mean(np.array(test_d_loss)) + + return d_correct, (test_g_loss, test_d_loss) + + def save(self, name, model_save_dir="ACGAN_model"): + """ + save model parameters + :param name: model parameter file name + :param model_save_dir: output dir + :return: + """ + if not os.path.isdir(model_save_dir): + os.makedirs(model_save_dir, exist_ok=True) + + generator_path = os.path.join(model_save_dir, "{}-generator.pt".format(name)) + discriminator_path = os.path.join(model_save_dir, "{}-discriminator.pt".format(name)) + torch.save(self.generator.get_params(), generator_path) + torch.save(self.discriminator.get_params(), discriminator_path) + + def load_model(self, name, model_save_dir="ACGAN_model"): + generator_path = os.path.join(model_save_dir, "{}-generator.pt".format(name)) + print(generator_path) + assert os.path.isfile(generator_path), "Generator model file doesn't exist" + self.generator.set_params(torch.load(generator_path)) + + def generate_data(self, target_labels): + """ + generate fake data corresponding to target labels + :param target_labels: target labels for new fake data (list or numpy array list) + :return: fake data tensor (image tensor) + """ + # Sample noise + z = Variable(self.FloatTensor(np.random.normal(0, 1, (len(target_labels), self.latent_dim)))).to(self.device) + labels = np.array(target_labels) + labels = Variable(self.LongTensor(labels)).to(self.device) + self.generator.eval() + return self.generator(z, labels) + + @staticmethod + def results_to_dataset_tensor(generated_data, dataset_name): + """ + transform GAN model output data to dataset tensor + :param generated_data: GAN model generated data + :param dataset_name: target dataset name + :return: + """ + transform_method_name = "gan_tensor_to_%s_data" % 
def __init__(self, num_classes, optimizer, device=None):
    """Store task metadata, build the network, and wire up the optimizer.

    `optimizer` may be an optimizer factory, an (optimizer-factory,
    scheduler-factory) pair, or None — in which case no network or
    optimizer is created (hyper-parameter-container mode).
    """
    super(BaseModel, self).__init__()
    self.num_classes = num_classes
    self.device = 'cpu' if not device else device
    if optimizer is None:
        # No training machinery requested; stop after recording metadata.
        return
    self.create_model()
    if isinstance(optimizer, (tuple, list)):
        opt_factory, sched_factory = optimizer
        self.optimizer = opt_factory(self.parameters())
        self.scheduler = sched_factory(self.optimizer)
    else:
        self.optimizer = optimizer(self.parameters())
        self.scheduler = None

@abstractmethod
def create_model(self):
    """Build the underlying network (subclass hook)."""
    pass

@abstractmethod
def forward(self, x):
    """Run a forward pass (subclass hook)."""
    pass

def set_params(self, model_params):
    """Overwrite the model weights with the given state dict."""
    self.load_state_dict(model_params)

def get_params(self):
    """Return the current model weights as a state dict."""
    return self.state_dict()

def get_gradients(self, data, model_len):
    """Gradient-extraction hook; intentionally a no-op in the base class."""
    pass
def solve_iters(self, data, num_iters=1, batch_size=32):
    """Run SGD for exactly `num_iters` mini-batch steps (epoch-agnostic).

    :param data: training dataset (map-style, usable by DataLoader)
    :param num_iters: number of optimization steps to take
    :param batch_size: mini-batch size
    :return: (state_dict, comp) — comp is a placeholder cost of 0,
             matching solve_inner's return convention.
    """
    self.train()  # set train mode
    model = self.to(self.device)
    total_iter = 0
    while total_iter < num_iters:
        # BUGFIX(idiom): the old loop bound an unused index named `iter`,
        # shadowing the builtin; iterate the loader directly.
        for X, y in DataLoader(data, batch_size=batch_size, shuffle=True):
            source, target = X.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            output = model(source)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()

            total_iter += 1
            if total_iter >= num_iters:
                break

    solution = self.get_params()
    comp = 0
    return solution, comp

def test(self, test_sets):
    """Evaluate on `test_sets`.

    :return: (number of correct predictions, mean per-sample loss)
    """
    self.eval()
    model = self.to(self.device)

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in DataLoader(test_sets, batch_size=1000):
            data, target = data.to(self.device), target.to(self.device)
            output = model(data)
            test_loss += self.loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_sets)
    return correct, test_loss

# Keep "test"-prefix collectors (e.g. pytest) from mistaking the
# evaluation routine for a test case when it is visible at module scope.
test.__test__ = False

def close(self):
    """Hook for releasing resources; nothing to release here."""
    pass
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs, expansion 4).

    norm='instancenorm' uses GroupNorm with one group per channel
    (equivalent to instance norm); any other value falls back to
    BatchNorm2d.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(Bottleneck, self).__init__()
        self.norm = norm
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)

        # Projection shortcut only when the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        # BUGFIX: previously returned the same tensor six times
        # (copy-paste from ResNet.forward's multi-feature return), which
        # breaks stacking inside nn.Sequential; BasicBlock.forward returns
        # a single tensor and so must this.
        return out
def ResNet18(channel=3, num_classes=10, norm='batchnorm'):
    """ResNet-18 built from BasicBlock."""
    model_name = 'resnet18'
    return ResNet(model_name, BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm=norm)

def ResNet34(num_classes=10):
    """ResNet-34 built from BasicBlock.

    BUGFIX: `num_classes` was passed positionally and landed in ResNet's
    `channel` parameter (leaving the head at its default size); pass it
    by keyword. Same fix applied to the deeper variants below.
    """
    model_name = 'resnet34'
    return ResNet(model_name, BasicBlock, [3,4,6,3], num_classes=num_classes)

def ResNet50(num_classes=10):
    """ResNet-50 built from Bottleneck."""
    model_name = 'resnet50'
    return ResNet(model_name, Bottleneck, [3,4,6,3], num_classes=num_classes)

def ResNet101(num_classes=10):
    """ResNet-101 built from Bottleneck."""
    model_name = 'resnet101'
    return ResNet(model_name, Bottleneck, [3,4,23,3], num_classes=num_classes)

def ResNet152(num_classes=10):
    """ResNet-152 built from Bottleneck."""
    model_name = 'resnet152'
    return ResNet(model_name, Bottleneck, [3,8,36,3], num_classes=num_classes)

def init_log(output_dir):
    """Configure root logging to `<output_dir>/log.log` plus the console.

    Returns the `logging` module so callers can do
    `log = init_log(d); log.info(...)`.
    NOTE(review): `basicConfig` is a no-op when the root logger already
    has handlers, so the file handler may not be installed in embedded
    setups — confirm if file output is required there.
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y%m%d-%H:%M:%S',
                        filename=os.path.join(output_dir, 'log.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    return logging
class Visualizer(object):
    """Thin wrapper around visdom for plotting images, curves, and text.

    Any attribute not defined here is forwarded to the underlying
    `visdom.Visdom` client via `__getattr__`.
    """

    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, **kwargs)
        # BUGFIX: `plot` called `.get(...)` and item-assignment on
        # `self.index`, but it was initialised to the int 1, raising
        # AttributeError on first use. `plot` now keeps its per-window
        # x-coordinates in a dedicated dict, while `plot_curves` keeps
        # using the scalar `self.index` counter as before.
        self.plot_index = {}
        self.index = 1
        self.log_text = ''

    def denormalize(self, x_hat, mean=None, std=None):
        """Undo channel-wise normalisation and rescale to [0, 255]."""
        x = x_hat.clone().detach().cpu()
        if mean is not None and std is not None:
            mean = torch.tensor(mean).reshape(3, 1, 1)
            std = torch.tensor(std).reshape(3, 1, 1)
            x = x * std + mean
        x = x.mul_(255).add_(0.5).clamp_(0, 255)
        return x.detach()

    def reinit(self, env='default', **kwargs):
        """Point this visualizer at a new visdom environment."""
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        """Plot a dict of name -> scalar, one curve window per name."""
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        """Show a dict of name -> image tensor, one window per name."""
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y, **kwargs):
        """Append scalar `y` to the curve shown in window `name`."""
        x = self.plot_index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append',
                      **kwargs)
        self.plot_index[name] = x + 1

    def plot_curves(self, d, iters, title='loss', xlabel='iters', ylabel='accuracy'):
        """Plot several named curves at iteration `iters` in one window."""
        name = list(d.keys())
        val = list(d.values())
        if len(val) == 1:
            y = np.array(val)
        else:
            y = np.array(val).reshape(-1, len(val))
        self.vis.line(Y=y,
                      X=np.array([self.index]),
                      win=title,
                      opts=dict(legend=name, title=title, xlabel=xlabel, ylabel=ylabel),
                      update=None if self.index == 0 else 'append')
        self.index = iters

    def img(self, name, img_, mean=None, std=None, **kwargs):
        """Show an image tensor (denormalised when mean/std given)."""
        img = self.denormalize(img_, mean, std)
        self.vis.images(img.numpy(),
                        win=name,
                        opts=dict(title=name),
                        **kwargs)

    def log(self, info, win='log_text'):
        """Append a timestamped line to a rolling text window."""
        # visdom text panes render HTML, hence the <br> line break.
        # NOTE(review): the separator was garbled in the source; <br>
        # matches visdom's HTML text rendering — confirm against upstream.
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        self.vis.text(self.log_text, win)

    def __getattr__(self, name):
        # Delegate anything we don't define to the visdom client.
        return getattr(self.vis, name)
+requests==2.32.3 +rpds-py==0.10.6 +safetensors==0.4.4 +scikit-learn==1.5.1 +scipy==1.14.0 +seaborn==0.13.2 +Send2Trash==1.8.0 +setuptools==72.1.0 +sip==6.6.2 +six==1.16.0 +stack-data==0.2.0 +sympy==1.12 +tenacity==9.0.0 +terminado==0.17.1 +testpath==0.6.0 +threadpoolctl==3.5.0 +timm==1.0.8 +toml==0.10.2 +torch==2.2.1 +torchaudio==2.2.1 +torchnet==0.0.4 +torchvision==0.17.1 +tornado==6.4.1 +tqdm==4.66.5 +traitlets==5.14.3 +triton==2.2.0 +typing_extensions==4.11.0 +tzdata==2024.1 +urllib3==2.2.2 +visdom==0.2.4 +wcwidth==0.2.5 +webencodings==0.5.1 +websocket-client==1.8.0 +wheel==0.43.0 +widgetsnbextension==3.5.2 +zipp==3.17.0