Skip to content

Commit

Permalink
feat(baselines) editing simII data generation
Browse files Browse the repository at this point in the history
  • Loading branch information
chancejohnstone committed Dec 2, 2024
1 parent 52dbab3 commit 55f8366
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
4 changes: 2 additions & 2 deletions baselines/fedht/fedht/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def client_fn(context: Context) -> Client:
test_dataset = train_dataset = MyDataset(
X_train[int(partition_id), :, :], y_train[:, int(partition_id)]
)
trainloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True)
trainloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)

# define model and set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand Down
19 changes: 9 additions & 10 deletions baselines/fedht/fedht/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
import random
import torch
import gzip

import flwr as fl
import hydra
Expand Down Expand Up @@ -68,23 +69,21 @@ def main(cfg: DictConfig):
num_classes = cfg.num_classes

# import data from fedht/data folder
with open('fedht/data/simII_train.pkl', 'rb') as file:
dataset = pickle.load(file)
# with open('fedht/data/simII_train.pkl', 'rb') as file:
# dataset = pickle.load(file)

with open('fedht/data/simII_test.pkl', 'rb') as file:
test_dataset = pickle.load(file)
# with open('fedht/data/simII_test.pkl', 'rb') as file:
# test_dataset = pickle.load(file)

X_test, y_test = test_dataset

# dataset = sim_data(num_obs, num_clients, num_features, 1, 1)
# X_test, y_test = sim_data(num_obs, 1, num_features, 1, 1)
# test_dataset = MyDataset(X_test[0, :, :], y_test[:, 0])
# simulate data
X_train, y_train, X_test, y_test = sim_data(200, num_clients, 1000, 1, 1)
train_dataset = X_train, y_train
test_dataset = MyDataset(X_test, y_test[:,0])
testloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)

# set client function
client_fn = generate_client_fn_simII(
dataset,
train_dataset,
cfg=cfg
)

Expand Down
10 changes: 6 additions & 4 deletions baselines/fedht/fedht/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def partition_data(data, num_partitions):
def sim_data(ni: int, num_clients: int, num_features: int, alpha=1, beta=1):
"""Simulate data for simII."""

np.random.seed(2025)
np.random.seed(2024)

# generate client-based model coefs
u = np.random.normal(0, alpha, num_clients)
x = np.zeros((num_features, num_clients))
x[0:99, :] = np.random.multivariate_normal(u, np.diag(np.ones(num_clients)), 99)
x[0:100, :] = np.random.multivariate_normal(u, np.diag(np.ones(num_clients)), 100)

# generate train observations
ivec = np.arange(1, num_features + 1)
Expand Down Expand Up @@ -89,15 +89,17 @@ def sim_data(ni: int, num_clients: int, num_features: int, alpha=1, beta=1):
# B = np.random.normal(0, beta, num_features)
# v = np.random.multivariate_normal(B, np.diag(np.ones(num_features)), num_clients)

error = np.random.multivariate_normal(u, np.diag(np.ones(num_clients)), ntest)
error2 = np.random.multivariate_normal(u, np.diag(np.ones(num_clients)), ntest)
xtest = np.zeros((num_features, num_clients))
xtest[0:100, :] = np.random.multivariate_normal(u, np.diag(np.ones(num_clients)), 100)
ztest = np.zeros((num_clients, ntest, num_features))
ytest = np.zeros((ni, num_clients))

# (num_clients, ni, num_features)
for i in range(ztest.shape[0]):
# train
ztest[i, :, :] = np.random.multivariate_normal(v[i], vari, ntest)
hold = np.matmul(ztest[i, :, :], x[:, i]) + error[:, i]
hold = np.matmul(ztest[i, :, :], xtest[:, i]) + error2[:, i]
ytest[:, i] = np.exp(hold) / (1 + np.exp(hold))

for j in range(num_clients):
Expand Down

0 comments on commit 55f8366

Please sign in to comment.