Skip to content

Commit

Permalink
Update flwr job object, client, server (#3008)
Browse files Browse the repository at this point in the history
Co-authored-by: Sean Yang <[email protected]>
  • Loading branch information
YuanTingHsieh and SYangster authored Oct 8, 2024
1 parent 39da4f0 commit 0106514
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/hello-world/hello-flower/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ Next, we run 2 Flower clients and Flower Server in parallel using NVFlare while
the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming.

```bash
python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics --use_client_api
python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics
```
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@ class FlowerClient(NumPyClient):
def __init__(self, context: Context):
super().__init__()
self.writer = SummaryWriter()
self.set_context(context)
self.flwr_context = context

if "step" not in context.state.metrics_records:
self.set_step(0)

def set_step(self, step: int):
context = self.get_context()
context.state = RecordSet(metrics_records={"step": MetricsRecord({"step": step})})
self.set_context(context)
self.flwr_context.state = RecordSet(metrics_records={"step": MetricsRecord({"step": step})})

def get_step(self):
context = self.get_context()
return int(context.state.metrics_records["step"]["step"])
return int(self.flwr_context.state.metrics_records["step"]["step"])

def fit(self, parameters, config):
step = self.get_step()
Expand Down
17 changes: 10 additions & 7 deletions examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from typing import List, Tuple

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig
from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from .task import Net, get_weights
Expand Down Expand Up @@ -53,13 +53,16 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
initial_parameters=parameters,
)


# Define config
config = ServerConfig(num_rounds=3)


# Flower ServerApp
app = ServerApp(
config=config,
strategy=strategy,
)
def server_fn(context: Context):
return ServerAppComponents(
strategy=strategy,
config=config,
)


app = ServerApp(server_fn=server_fn)
8 changes: 5 additions & 3 deletions examples/hello-world/hello-flower/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from argparse import ArgumentParser

from nvflare.app_opt.flower.flower_job import FlowerJob
from nvflare.app_opt.flower.flower_pt_job import FlowerPyTorchJob
from nvflare.client.api import ClientAPIType
from nvflare.client.api_spec import CLIENT_API_TYPE_KEY

Expand All @@ -30,10 +30,12 @@ def main():
args = parser.parse_args()

env = {}
if args.use_client_api:
if args.stream_metrics or args.use_client_api:
# needs to init client api to stream metrics
# only external client api works with the current flower integration
env = {CLIENT_API_TYPE_KEY: ClientAPIType.EX_PROCESS_API.value}

job = FlowerJob(
job = FlowerPyTorchJob(
name=args.job_name,
flower_content=args.content_dir,
stream_metrics=args.stream_metrics,
Expand Down
6 changes: 2 additions & 4 deletions nvflare/app_opt/flower/flower_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from nvflare.app_common.widgets.external_configurator import ExternalConfigurator
from nvflare.app_common.widgets.metric_relay import MetricRelay
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.fuel.utils.pipe.cell_pipe import CellPipe
from nvflare.fuel.utils.validation_utils import check_object_type
from nvflare.job_config.api import FedJob
Expand Down Expand Up @@ -104,10 +103,9 @@ def __init__(
# server side - need analytics_receiver
if analytics_receiver:
check_object_type("analytics_receiver", analytics_receiver, AnalyticsReceiver)
self.to_server(analytics_receiver, "analytics_receiver")
else:
analytics_receiver = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])

self.to_server(analytics_receiver, "analytics_receiver")
raise ValueError("Missing analytics receiver on the server side.")

# client side
# cell pipe
Expand Down
88 changes: 88 additions & 0 deletions nvflare/app_opt/flower/flower_pt_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from nvflare.app_common.tie.defs import Constant
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver

from .flower_job import FlowerJob


class FlowerPyTorchJob(FlowerJob):
def __init__(
self,
name: str,
flower_content: str,
min_clients: int = 1,
mandatory_clients: Optional[List[str]] = None,
database: str = "",
server_app_args: list = None,
superlink_ready_timeout: float = 10.0,
configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
start_task_timeout=Constant.START_TASK_TIMEOUT,
max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
per_msg_timeout=10.0,
tx_timeout=100.0,
client_shutdown_timeout=5.0,
stream_metrics=False,
analytics_receiver=None,
extra_env: dict = None,
):
"""
Flower Job.
Args:
name (str): Name of the job.
flower_content (str): Content for the flower job.
min_clients (int, optional): The minimum number of clients for the job. Defaults to 1.
mandatory_clients (List[str], optional): List of mandatory clients for the job. Defaults to None.
database (str, optional): Database string. Defaults to "".
server_app_args (list, optional): List of arguments to pass to the server application. Defaults to None.
superlink_ready_timeout (float, optional): Timeout for the superlink to be ready. Defaults to 10.0 seconds.
configure_task_timeout (float, optional): Timeout for configuring the task. Defaults to Constant.CONFIG_TASK_TIMEOUT.
start_task_timeout (float, optional): Timeout for starting the task. Defaults to Constant.START_TASK_TIMEOUT.
max_client_op_interval (float, optional): Maximum interval between client operations. Defaults to Constant.MAX_CLIENT_OP_INTERVAL.
progress_timeout (float, optional): Timeout for workflow progress. Defaults to Constant.WORKFLOW_PROGRESS_TIMEOUT.
per_msg_timeout (float, optional): Timeout for receiving individual messages. Defaults to 10.0 seconds.
tx_timeout (float, optional): Timeout for transmitting data. Defaults to 100.0 seconds.
client_shutdown_timeout (float, optional): Timeout for client shutdown. Defaults to 5.0 seconds.
stream_metrics (bool, optional): Whether to stream metrics from Flower client to Flare
analytics_receiver (AnalyticsReceiver, optional): the AnalyticsReceiver to use to process received metrics.
extra_env (dict, optional): optional extra env variables to be passed to Flower client
"""
analytics_receiver = (
analytics_receiver if analytics_receiver else TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
)

super().__init__(
name=name,
flower_content=flower_content,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
database=database,
server_app_args=server_app_args,
superlink_ready_timeout=superlink_ready_timeout,
configure_task_timeout=configure_task_timeout,
start_task_timeout=start_task_timeout,
max_client_op_interval=max_client_op_interval,
progress_timeout=progress_timeout,
per_msg_timeout=per_msg_timeout,
tx_timeout=tx_timeout,
client_shutdown_timeout=client_shutdown_timeout,
stream_metrics=stream_metrics,
analytics_receiver=analytics_receiver,
extra_env=extra_env,
)

0 comments on commit 0106514

Please sign in to comment.