diff --git a/examples/hello-world/hello-pt-resnet/README.md b/examples/hello-world/hello-pt-resnet/README.md new file mode 100644 index 0000000000..4b1a0f51a6 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/README.md @@ -0,0 +1,20 @@ +# Hello PyTorch ResNet + +Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier +using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) +and [PyTorch](https://pytorch.org/) as the deep learning training framework. Comparing with the Hello PyTorch example, it uses the torchvision ResNet, +instead of the SimpleNetwork. + +> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the client train code. + +The Job API only supports the object instance created directly out of the Python Class. It does not support +the object instance created through using the Python function. Comparing with the hello-pt example, +if we replace the SimpleNetwork() object with the resnet18(num_classes=10), +the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. +As shown in the [torchvision reset](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705), +the resnet18 is a Python function, which creates and returns a ResNet object. The job API can +only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18". + +This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10) +object instance in the job API. After replacing the SimpleNetwork() with the Resnet18(num_classes=10), +you can follow the exact same steps in the hello-pt example to run the fedavg_script_runner_pt.py. diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py new file mode 100644 index 0000000000..ece630dcdb --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.resnet_18 import Resnet18 + +from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob +from nvflare.job_config.script_runner import ScriptRunner + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/hello-pt_cifar10_fl.py" + + job = FedAvgJob( + name="hello-pt_cifar10_fedavg", + n_clients=n_clients, + num_rounds=num_rounds, + initial_model=Resnet18(num_classes=10), + ) + + # Add clients + for i in range(n_clients): + executor = ScriptRunner( + script=train_script, + script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}" + ) + job.to(executor, f"site-{i + 1}") + + # job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0") diff --git a/examples/hello-world/hello-pt-resnet/requirements.txt b/examples/hello-world/hello-pt-resnet/requirements.txt new file mode 100644 index 0000000000..919cc32ba2 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/requirements.txt @@ -0,0 +1,3 @@ +nvflare~=2.5.0rc +torch +torchvision diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py new file mode 100644 index 0000000000..860e6f0cac --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from resnet_18 import Resnet18 +from torch import nn +from torch.optim import SGD +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +import nvflare.client as flare +from nvflare.client.tracking import SummaryWriter + +DATASET_PATH = "/tmp/nvflare/data" + + +def main(): + batch_size = 4 + epochs = 5 + lr = 0.01 + model = Resnet18(num_classes=10) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + loss = nn.CrossEntropyLoss() + optimizer = SGD(model.parameters(), lr=lr, momentum=0.9) + transforms = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + flare.init() + sys_info = flare.system_info() + client_name = sys_info["site_name"] + + train_dataset = CIFAR10( + root=os.path.join(DATASET_PATH, client_name), + transform=transforms, + download=True, + train=True, + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + summary_writer = SummaryWriter() + while flare.is_running(): + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + model.load_state_dict(input_model.params) + model.to(device) + + steps = epochs * len(train_loader) + for epoch in range(epochs): + running_loss = 0.0 + for i, batch in enumerate(train_loader): + images, labels = batch[0].to(device), batch[1].to(device) + optimizer.zero_grad() + + predictions = model(images) + cost = loss(predictions, labels) + cost.backward() + optimizer.step() + + running_loss += cost.cpu().detach().numpy() / images.size()[0] + if i % 3000 == 0: + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") + global_step = input_model.current_round * steps + epoch * len(train_loader) + i + summary_writer.add_scalar( + tag="loss_for_each_batch", + scalar=running_loss, + global_step=global_step, + ) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(model.state_dict(), PATH) + + output_model = flare.FLModel( + params=model.cpu().state_dict(), + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/hello-pt-resnet/src/resnet_18.py b/examples/hello-world/hello-pt-resnet/src/resnet_18.py new file mode 100644 index 0000000000..3420fdd741 --- /dev/null +++ b/examples/hello-world/hello-pt-resnet/src/resnet_18.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchvision.models import ResNet +from torchvision.models._utils import _ovewrite_named_param +from torchvision.models.resnet import BasicBlock, ResNet18_Weights + + +class Resnet18(ResNet): + def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any): + self.num_classes = num_classes + + weights = ResNet18_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs) + + if weights is not None: + super().load_state_dict(weights.get_state_dict(progress=progress))