Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cyclic WFController #2554

Merged
merged 2 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, site_name: str, script_path: str, script_args: str = None, re

def run(self):
"""Call the task_fn with any required arguments."""
self.logger.info(f"\n start task run() with full path: {self.script_full_path}")
self.logger.info(f"start task run() with full path: {self.script_full_path}")
try:
curr_argv = sys.argv
builtins.print = log_print if self.redirect_print_to_log else print_fn
Expand Down
24 changes: 0 additions & 24 deletions nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import List

from nvflare.apis.fl_constant import FLMetaKey
Expand Down Expand Up @@ -48,7 +47,6 @@ def __init__(
The model_persistor will also save the model after training.

Provides the default implementations for the follow routines:
- def sample_clients(self, min_clients)
- def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
- def update_model(self, aggr_result)

Expand All @@ -74,28 +72,6 @@ def __init__(

self.current_round = None

def sample_clients(self, num_clients):
"""Called by the `run` routine to get a list of available clients.

Args:
min_clients: number of clients to return.

Returns: list of clients.

"""

clients = self.engine.get_clients()

if num_clients <= len(clients):
random.shuffle(clients)
clients = clients[0:num_clients]
else:
self.info(
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
)

return clients

@staticmethod
def _check_results(results: List[FLModel]):
empty_clients = []
Expand Down
60 changes: 60 additions & 0 deletions nvflare/app_common/workflows/cyclic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .wf_controller import WFController


class Cyclic(WFController):
def __init__(
self,
*args,
min_clients: int = 2,
num_rounds: int = 5,
start_round: int = 0,
**kwargs,
):
"""The Cyclic Workflow controller to implement the Cyclic Weight Transfer (CWT) algorithm.

Args:
min_clients (int, optional): The minimum number of clients. Defaults to 2.
num_rounds (int, optional): The total number of training rounds. Defaults to 5.
start_round (int, optional): The starting round number. Defaults to 0
"""
super().__init__(*args, **kwargs)

self.min_clients = min_clients
self.num_rounds = num_rounds
self.start_round = start_round
self.current_round = None

def run(self) -> None:
self.info("Start Cyclic.")

model = self.load_model()
model.start_round = self.start_round
model.total_rounds = self.num_rounds

for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
self.info(f"Round {self.current_round} started.")
model.current_round = self.current_round

clients = self.sample_clients(self.min_clients)

for client in clients:
result = self.send_model_and_wait(targets=[client], data=model)[0]
model.params, model.meta = result.params, result.meta

self.save_model(model)

self.info("Finished Cyclic.")
18 changes: 18 additions & 0 deletions nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import gc
import random
from abc import ABC, abstractmethod
from typing import Callable, List, Union

Expand Down Expand Up @@ -343,6 +344,23 @@ def save_model(self, model):
else:
self.error("persistor not configured, model will not be saved")

def sample_clients(self, num_clients):
clients = self.engine.get_clients()

if num_clients < len(clients):
random.shuffle(clients)
clients = clients[0:num_clients]
self.info(
f"num_clients ({num_clients}) is less than the number of available clients. Returning a random subset of {num_clients} clients."
)
elif num_clients > len(clients):
self.info(
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
)
self.info(f"Sampled clients: {[client.name for client in clients]}")

return clients

def stop_controller(self, fl_ctx: FLContext):
self.fl_ctx = fl_ctx
self.finalize()
11 changes: 11 additions & 0 deletions nvflare/app_common/workflows/wf_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,14 @@ def save_model(self, model: FLModel):
None
"""
super().save_model(model)

def sample_clients(self, num_clients):
"""Returns a list of available clients.

Args:
min_clients: number of clients to return.

Returns: list of clients.

"""
return super().sample_clients(num_clients)
Loading