Merge branch 'main' into cyclic_wf
SYangster authored May 16, 2024
2 parents 67c70c0 + 70c385e commit 032eaae
Showing 114 changed files with 4,420 additions and 216 deletions.
14 changes: 8 additions & 6 deletions docs/real_world_fl/cloud_deployment.rst
@@ -44,11 +44,11 @@ To run NVFlare dashboard on Azure, run:
.. note::

The script also requires sshpass and jq. Both can be installed on Ubuntu, with:
The script also requires sshpass, dig and jq. All can be installed on Ubuntu, with:

.. code-block:: shell
sudo apt install sshpass jq
sudo apt install sshpass bind9-dnsutils jq
Users only need to enter an email address and press Enter. The user needs to remember this email and the temporary password that will be provided, as
these are the login credentials for the NVFLARE Dashboard once the Dashboard is up and running.
@@ -101,11 +101,11 @@ To run NVFlare dashboard on AWS, run:
.. note::

The script also requires sshpass and jq. They can be installed on Ubuntu, with:
The script also requires sshpass, dig and jq. They can be installed on Ubuntu, with:

.. code-block:: shell
sudo apt install sshpass jq
sudo apt install sshpass bind9-dnsutils jq
AWS manages authentication via an AWS access_key and access_secret; you will need these credentials before you can start creating AWS infrastructure.

@@ -128,9 +128,10 @@ You can accept all default values by pressing ENTER.

.. code-block:: none
This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed.
This script requires az (Azure CLI), sshpass, dig and jq. Now checking if they are installed.
Checking if az exists. => found
Checking if sshpass exists. => found
Checking if dig exists. => found
Checking if jq exists. => found
Cloud VM image, press ENTER to accept default Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest:
Cloud VM size, press ENTER to accept default Standard_B2ms:
@@ -190,9 +191,10 @@ You can accept all default values by pressing ENTER.

.. code-block::
This script requires aws (AWS CLI), sshpass and jq. Now checking if they are installed.
This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed.
Checking if aws exists. => found
Checking if sshpass exists. => found
Checking if dig exists. => found
Checking if jq exists. => found
If the server requires additional dependencies, please copy the requirements.txt to /home/nvflare/workspace/aws/nvflareserver/startup.
Press ENTER when it's done or no additional dependencies.
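The prerequisite check in the console output above can be approximated in a few lines of Python. This is only a sketch of the idea, not the logic of the actual deployment script: the function names `check_tools` and `report` are hypothetical, and the tool list is taken from the AWS output shown in this hunk.

```python
import shutil


def check_tools(tools):
    """Return a dict mapping each tool name to True if it is found on PATH."""
    return {tool: shutil.which(tool) is not None for tool in tools}


def report(tools):
    """Print one status line per tool and return the names that are missing."""
    results = check_tools(tools)
    for tool, found in results.items():
        status = "found" if found else "not found"
        print(f"Checking if {tool} exists. => {status}")
    return [tool for tool, found in results.items() if not found]


if __name__ == "__main__":
    # Tool names as listed in the AWS console output (assumed to be the full set).
    missing = report(["aws", "sshpass", "dig", "jq"])
    if missing:
        print("Missing tools: " + ", ".join(missing))
```

On Ubuntu, `dig` is provided by the `bind9-dnsutils` package named in the updated install command, so a missing `dig` maps to that package rather than to a package called `dig`.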
Binary file modified docs/resources/client_api.png
@@ -48,6 +48,7 @@
"id": "mlflow_receiver_with_tracking_uri",
"path": "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver",
"args": {
tracking_uri = "file:///{WORKSPACE}/{JOB_ID}/mlruns"
"kwargs": {
"experiment_name": "hello-pt-experiment",
"run_name": "hello-pt-with-mlflow",
@@ -139,7 +139,7 @@ def evaluate(input_weights):
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i
global_step = input_model.current_round * steps + epoch * len(trainloader) + i
mlflow.log_metric("loss", running_loss / 2000, global_step)
running_loss = 0.0
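The change to `global_step` replaces `batch_size` with the number of batches per epoch. A minimal sketch of why that matters follows. It assumes `steps` equals `local_epochs * len(trainloader)`, which this hunk does not show; the function names and the small numbers are illustrative only.

```python
def global_step(current_round, epoch, i, batches_per_epoch, local_epochs):
    """Step index in the style of the updated line (assumes steps = local_epochs * batches_per_epoch)."""
    steps_per_round = local_epochs * batches_per_epoch
    return current_round * steps_per_round + epoch * batches_per_epoch + i


def old_global_step(current_round, epoch, i, batch_size, local_epochs):
    """Step index in the style of the replaced line, which used batch_size as the stride."""
    return current_round * local_epochs * batch_size + epoch * batch_size + i


# Illustrative sizes: 3 rounds, 2 local epochs, 5 batches per epoch, batch size 4.
rounds, local_epochs, batches_per_epoch, batch_size = 3, 2, 5, 4

new_steps = [
    global_step(r, e, i, batches_per_epoch, local_epochs)
    for r in range(rounds)
    for e in range(local_epochs)
    for i in range(batches_per_epoch)
]
old_steps = [
    old_global_step(r, e, i, batch_size, local_epochs)
    for r in range(rounds)
    for e in range(local_epochs)
    for i in range(batches_per_epoch)
]

print("new steps are unique:", len(set(new_steps)) == len(new_steps))
print("old steps are unique:", len(set(old_steps)) == len(old_steps))
```

With the batch-count stride, every (round, epoch, batch) triple maps to a distinct, increasing step. With the batch-size stride, steps from different epochs overlap whenever an epoch has more batches than `batch_size`, so logged metrics can land on the same x-axis value.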

46 changes: 44 additions & 2 deletions examples/hello-world/step-by-step/cifar10/sag/sag.ipynb
@@ -232,8 +232,8 @@
"source": [
"! nvflare job create -j /tmp/nvflare/jobs/cifar10_sag_pt -w sag_pt_in_proc \\\n",
"-f meta.conf min_clients=2 \\\n",
"-f config_fed_client.conf app_script=train.py app_config=\"--batch_size 4 --dataset_path {CIFAR10_ROOT} --num_workers 2\" \\\n",
"-f config_fed_server.conf num_rounds=5 \\\n",
"-f config_fed_client.conf app_script=train_with_mlflow.py app_config=\"--batch_size 4 --dataset_path {CIFAR10_ROOT} --num_workers 2\" \\\n",
"-f config_fed_server.conf num_rounds=2 \\\n",
"-sd ../code/fl \\\n",
"-force"
]
@@ -289,6 +289,48 @@
"The next 5 examples will use the same ScatterAndGather workflow, but will demonstrate different execution APIs and feature.\n",
"In the next example [sag_deploy_map](../sag_deploy_map/sag_deploy_map.ipynb), we will learn about the deploy_map configuration for deployment of apps to different sites."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a49b430b-a65b-4b1e-8793-9b3befcfcfd9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!tree /tmp/nvflare/jobs/cifar10_sag_pt_workspace/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50594df7-b4c9-4e5e-944a-403b5a105c27",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!mlflow ui --port 5000 --backend-store-uri /tmp/nvflare/jobs/cifar10_sag_pt_workspace/server/simulate_job/mlruns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af2b6628-61af-4bc8-84d4-a9876a27c7c2",
"metadata": {},
"outputs": [],
"source": [
"!tensorboard --logdir=/tmp/nvflare/jobs/cifar10_sag_pt_workspace/server/simulate_job/tb_events"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3ad11c3-6ef7-46cd-8778-0090505b14e1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
4 changes: 4 additions & 0 deletions integration/monai/examples/README.md
@@ -1,5 +1,9 @@
# Examples of MONAI-NVFlare Integration

### [Converting MONAI Code to a Federated Learning Setting](./mednist/README.md)
A tutorial to show how simple it can be to run an end-to-end classification pipeline with MONAI
and deploy it in a federated learning setting using NVFlare.

### [Simulated Federated Learning for 3D spleen CT segmentation](./spleen_ct_segmentation_sim/README.md)
An example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)
to train a medical image analysis model using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629))
3 changes: 3 additions & 0 deletions integration/monai/examples/mednist/.gitignore
@@ -0,0 +1,3 @@
# nvflare artifacts for this example
fedavg_workspace
jobs
18 changes: 18 additions & 0 deletions integration/monai/examples/mednist/README.md
@@ -0,0 +1,18 @@
## Converting MONAI Code to a Federated Learning Setting

In this tutorial, we show how simple it can be to run an end-to-end classification pipeline with MONAI
and deploy it in a federated learning setting using NVFlare.

### 1. Standalone training with MONAI
[monai_101.ipynb](./monai_101.ipynb) is based on the [MONAI 101 classification tutorial](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and shows each step required in only a few lines of code, including

- Dataset download
- Data pre-processing
- Define a DenseNet-121 and run training
- Check the results on test dataset

### 2. Federated learning with MONAI
[monai_101_fl.ipynb](./monai_101_fl.ipynb) shows how we can simply put the code introduced above into a Python script and convert it to run in an FL scenario using NVFlare.

To achieve this, we utilize the [`FedAvg`](https://arxiv.org/abs/1602.05629) algorithm and NVFlare's [Client
API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api).
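As a rough illustration of what FedAvg does with the client models, the sketch below computes a weighted average of client parameters in plain Python. It is not NVFlare's aggregator: the function name `fedavg`, the list-of-floats parameter format, and the choice of the per-round step count as the weight are assumptions made for this example.

```python
def fedavg(client_updates):
    """Weighted average of client parameters.

    client_updates: list of (params, num_steps) pairs, where params maps a
    parameter name to a list of floats and num_steps is the weight (assumed
    here to be the number of local training steps in the round).
    """
    if not client_updates:
        raise ValueError("at least one client update is required")
    total_weight = sum(num_steps for _, num_steps in client_updates)
    if total_weight <= 0:
        raise ValueError("total weight must be positive")

    first_params = client_updates[0][0]
    averaged = {}
    for name, values in first_params.items():
        averaged[name] = [
            sum(params[name][k] * num_steps for params, num_steps in client_updates)
            / total_weight
            for k in range(len(values))
        ]
    return averaged


# Two hypothetical clients; the second trained for three times as many steps.
updates = [
    ({"w": [1.0, 2.0]}, 1),
    ({"w": [3.0, 4.0]}, 3),
]
global_params = fedavg(updates)
print(global_params)
```

The client with more local steps pulls the average toward its own parameters, which is the core idea of the weighting in FedAvg.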
169 changes: 169 additions & 0 deletions integration/monai/examples/mednist/code/monai_mednist_train.py
@@ -0,0 +1,169 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MONAI Example adapted from https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb
#
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.engines import SupervisedTrainer
from monai.handlers import StatsHandler, TensorBoardStatsHandler
from monai.inferers import SimpleInferer
from monai.networks import eval_mode
from monai.networks.nets import densenet121
from monai.transforms import Compose, EnsureChannelFirstD, LoadImageD, ScaleIntensityD

# (1) import nvflare client API
import nvflare.client as flare

# (optional) metrics
from nvflare.client.tracking import SummaryWriter

print_config()


def main():
    # (2) initializes NVFlare client API
    flare.init()

    # Setup data directory
    directory = os.environ.get("MONAI_DATA_DIRECTORY")
    root_dir = tempfile.mkdtemp() if directory is None else directory
    print(root_dir)

    # Use MONAI transforms to preprocess data
    transform = Compose(
        [
            LoadImageD(keys="image", image_only=True),
            EnsureChannelFirstD(keys="image"),
            ScaleIntensityD(keys="image"),
        ]
    )

    # Prepare datasets using MONAI Apps
    dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True)

    # Define a network and a supervised trainer

    # If available, we use GPU to speed things up.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    max_epochs = 1  # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each.
    model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)

    train_loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    trainer = SupervisedTrainer(
        device=torch.device(DEVICE),
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
        loss_function=torch.nn.CrossEntropyLoss(),
        inferer=SimpleInferer(),
        train_handlers=StatsHandler(),
    )

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    summary_writer = SummaryWriter()
    train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=summary_writer)
    train_tensorboard_stats_handler.attach(trainer)

    # (optional) calculate total steps
    steps = max_epochs * len(train_loader)
    # Run the training

    while flare.is_running():
        # (3) receives FLModel from NVFlare
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (4) loads model from NVFlare and sends it to GPU
        trainer.network.load_state_dict(input_model.params)
        trainer.network.to(DEVICE)

        trainer.run()

        # (5) wraps evaluation logic into a method to re-use for
        #       evaluation on both trained and received model
        def evaluate(input_weights):
            # Create model for evaluation
            eval_model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)
            eval_model.load_state_dict(input_weights)

            # Check the prediction on the test dataset
            dataset_dir = Path(root_dir, "MedNIST")
            class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir())
            testdata = MedNISTDataset(
                root_dir=root_dir, transform=transform, section="test", download=False, runtime_cache=True
            )
            correct = 0
            total = 0
            max_items_to_print = 10
            _print = 0
            with eval_mode(eval_model):
                for item in DataLoader(testdata, batch_size=512, num_workers=0):  # changed to do batch processing
                    prob = np.array(eval_model(item["image"].to(DEVICE)).detach().to("cpu"))
                    pred = [class_names[p] for p in prob.argmax(axis=1)]
                    gt = item["class_name"]
                    # changed the logic a bit from tutorial to compute accuracy on full test set
                    # but only print for some.
                    for _gt, _pred in zip(gt, pred):
                        if _print < max_items_to_print:
                            print(f"Class prediction is {_pred}. Ground-truth: {_gt}")
                            _print += 1

                        # compute accuracy
                        total += 1
                        correct += float(_pred == _gt)

            print(f"Accuracy of the network on the {total} test images: {100 * correct // total} %")
            return correct / total

        # (6) evaluate on received model for model selection
        accuracy = evaluate(input_model.params)
        summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round)

        # (7) construct trained FL model
        output_model = flare.FLModel(
            params=trainer.network.cpu().state_dict(),
            metrics={"accuracy": accuracy},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )
        # (8) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
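The `evaluate` function in this script counts matches between predicted and ground-truth class names, prints the percentage with floor division (`//`), and returns the exact ratio. The standalone sketch below isolates that accuracy bookkeeping without MONAI or NVFlare. The helper name `accuracy` and the three sample labels are illustrative, not taken from the commit.

```python
def accuracy(predictions, ground_truth):
    """Fraction of positions where the predicted label equals the true label."""
    if len(predictions) != len(ground_truth):
        raise ValueError("predictions and ground_truth must have the same length")
    total = len(ground_truth)
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty test set")
    correct = sum(1 for pred, truth in zip(predictions, ground_truth) if pred == truth)
    return correct / total


# Illustrative labels only; two of the three predictions are right.
preds = ["Hand", "HeadCT", "CXR"]
truth = ["Hand", "HeadCT", "Hand"]

acc = accuracy(preds, truth)
floored_percent = 100 * 2.0 // 3  # same style as the script's print statement
exact_percent = 100 * acc

print(f"floored: {floored_percent} %")
print(f"exact:   {exact_percent:.2f} %")
```

Because `correct` is a float in the script, the floor division yields a whole-number float, so the printed percentage is rounded down while the value sent back in `metrics` keeps full precision.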
Binary file added integration/monai/examples/mednist/figs/tb.png