Skip to content

Commit

Permalink
Add the hello-pt-resnet example (#2954)
Browse files Browse the repository at this point in the history
* Add the hello-pt-resnet example.

* Removed the no use SimpleNetwork.

* codestyle fix for hello-pt-resnet example.

* renamed the simple_network.py -> resnet_18.py. And the resnet18 link to ReadMe.

* updated license year.

* codestyle fix.

* black codestyle fix.

* codestyle fix.

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
yhwen and YuanTingHsieh authored Sep 26, 2024
1 parent 03ea272 commit a552e99
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 0 deletions.
20 changes: 20 additions & 0 deletions examples/hello-world/hello-pt-resnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Hello PyTorch ResNet

Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier
using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629))
and [PyTorch](https://pytorch.org/) as the deep learning training framework. Comparing with the Hello PyTorch example, it uses the torchvision ResNet,
instead of the SimpleNetwork.

> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the client train code.
The Job API only supports the object instance created directly out of the Python Class. It does not support
the object instance created through using the Python function. Comparing with the hello-pt example,
if we replace the SimpleNetwork() object with the resnet18(num_classes=10),
the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function.
As shown in the [torchvision reset](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705),
the resnet18 is a Python function, which creates and returns a ResNet object. The job API can
only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18".

This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10)
object instance in the job API. After replacing the SimpleNetwork() with the Resnet18(num_classes=10),
you can follow the exact same steps in the hello-pt example to run the fedavg_script_runner_pt.py.
41 changes: 41 additions & 0 deletions examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.resnet_18 import Resnet18

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
n_clients = 2
num_rounds = 2
train_script = "src/hello-pt_cifar10_fl.py"

job = FedAvgJob(
name="hello-pt_cifar10_fedavg",
n_clients=n_clients,
num_rounds=num_rounds,
initial_model=Resnet18(num_classes=10),
)

# Add clients
for i in range(n_clients):
executor = ScriptRunner(
script=train_script,
script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to(executor, f"site-{i + 1}")

# job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
3 changes: 3 additions & 0 deletions examples/hello-world/hello-pt-resnet/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
nvflare~=2.5.0rc
torch
torchvision
103 changes: 103 additions & 0 deletions examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from resnet_18 import Resnet18
from torch import nn
from torch.optim import SGD
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter

DATASET_PATH = "/tmp/nvflare/data"


def main():
batch_size = 4
epochs = 5
lr = 0.01
model = Resnet18(num_classes=10)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
transforms = Compose(
[
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)

flare.init()
sys_info = flare.system_info()
client_name = sys_info["site_name"]

train_dataset = CIFAR10(
root=os.path.join(DATASET_PATH, client_name),
transform=transforms,
download=True,
train=True,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

summary_writer = SummaryWriter()
while flare.is_running():
input_model = flare.receive()
print(f"current_round={input_model.current_round}")

model.load_state_dict(input_model.params)
model.to(device)

steps = epochs * len(train_loader)
for epoch in range(epochs):
running_loss = 0.0
for i, batch in enumerate(train_loader):
images, labels = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()

predictions = model(images)
cost = loss(predictions, labels)
cost.backward()
optimizer.step()

running_loss += cost.cpu().detach().numpy() / images.size()[0]
if i % 3000 == 0:
print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
global_step = input_model.current_round * steps + epoch * len(train_loader) + i
summary_writer.add_scalar(
tag="loss_for_each_batch",
scalar=running_loss,
global_step=global_step,
)
running_loss = 0.0

print("Finished Training")

PATH = "./cifar_net.pth"
torch.save(model.state_dict(), PATH)

output_model = flare.FLModel(
params=model.cpu().state_dict(),
meta={"NUM_STEPS_CURRENT_ROUND": steps},
)

flare.send(output_model)


if __name__ == "__main__":
main()
33 changes: 33 additions & 0 deletions examples/hello-world/hello-pt-resnet/src/resnet_18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

from torchvision.models import ResNet
from torchvision.models._utils import _ovewrite_named_param
from torchvision.models.resnet import BasicBlock, ResNet18_Weights


class Resnet18(ResNet):
def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
self.num_classes = num_classes

weights = ResNet18_Weights.verify(weights)

if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

if weights is not None:
super().load_state_dict(weights.get_state_dict(progress=progress))

0 comments on commit a552e99

Please sign in to comment.