Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Controller Refactor Part 1: separate communication #2390

Merged
merged 3 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions nvflare/apis/controller_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,19 @@ def start_controller(self, fl_ctx: FLContext):
"""
pass

@abstractmethod
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
    """This is the control logic for the RUN.

    NOTE: this is running in a separate thread, and its life is the duration of the RUN.

    Args:
        abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller.
        fl_ctx: the FL context

    """
    pass

@abstractmethod
def stop_controller(self, fl_ctx: FLContext):
"""Stops the controller.
Expand Down
2 changes: 2 additions & 0 deletions nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class EventType(object):

BEFORE_PULL_TASK = "_before_pull_task"
AFTER_PULL_TASK = "_after_pull_task"
BEFORE_PROCESS_TASK_REQUEST = "_before_process_task_request"
AFTER_PROCESS_TASK_REQUEST = "_after_process_task_request"
BEFORE_PROCESS_SUBMISSION = "_before_process_submission"
AFTER_PROCESS_SUBMISSION = "_after_process_submission"

Expand Down
949 changes: 42 additions & 907 deletions nvflare/apis/impl/controller.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions nvflare/apis/impl/task_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def start_controller(self, fl_ctx: FLContext):
if not self.task_result_filters:
self.task_result_filters = {}

def control_flow(self, fl_ctx: FLContext):
    # Intentionally a no-op: this controller defines no control logic here.
    # NOTE(review): the abstract control_flow shown in controller_spec.py takes
    # (abort_signal, fl_ctx); this override only accepts fl_ctx -- confirm that
    # callers never pass abort_signal positionally to this implementation.
    pass

def stop_controller(self, fl_ctx: FLContext):
    """Stop the controller; this implementation does nothing.

    Args:
        fl_ctx: the FL context (unused here)
    """
    pass

Expand Down
257 changes: 257 additions & 0 deletions nvflare/apis/impl/wf_comm_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode, SiteType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.task_utils import apply_filters
from nvflare.apis.wf_comm_spec import WFCommSpec
from nvflare.private.fed.utils.fed_utils import get_target_names
from nvflare.private.privacy_manager import Scope
from nvflare.security.logging import secure_format_exception


class WFCommClient(FLComponent, WFCommSpec):
    """WFCommSpec implementation that delivers tasks as aux requests.

    Every task is sent to its targets with ``engine.send_aux_request`` on the
    ``ReservedTopic.DO_TASK`` topic. Task-data filters are applied before
    sending and task-result filters are applied to every OK reply.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()
        # Filter tables handed to apply_filters(); both start out empty.
        self.task_data_filters = {}
        self.task_result_filters = {}

    def broadcast(
        self,
        task: Task,
        fl_ctx: FLContext,
        targets: Union[List[Client], List[str], None] = None,
        min_responses: int = 0,
        wait_time_after_min_received: int = 0,
    ):
        """Send the task to the targets; identical to ``broadcast_and_wait``.

        This implementation has no non-blocking mode: the call returns only
        after ``broadcast_and_wait`` has returned.

        Args:
            task: the task to send
            fl_ctx: the FL context
            targets: clients (objects or names) to send the task to
            min_responses: forwarded to ``broadcast_and_wait``
            wait_time_after_min_received: forwarded to ``broadcast_and_wait``

        Returns:
            dict mapping client name to that client's reply.
        """
        return self.broadcast_and_wait(task, fl_ctx, targets, min_responses, wait_time_after_min_received)

    def broadcast_and_wait(
        self,
        task: Task,
        fl_ctx: FLContext,
        targets: Union[List[Client], List[str], None] = None,
        min_responses: int = 0,
        wait_time_after_min_received: int = 0,
        abort_signal: Signal = None,
    ):
        """Filter the task data, send it to all targets and collect the replies.

        Steps: fire BEFORE/AFTER_TASK_DATA_FILTER around the task-data filters,
        validate targets and timeout, register one ClientTask per target, send
        the aux request, run result filters and ``task.result_received_cb`` for
        each reply, and finally run ``task.task_done_cb``.

        Args:
            task: the task to send; ``task.data`` is replaced by the filtered data
            fl_ctx: the FL context
            targets: clients (objects or names) to send the task to
            min_responses: accepted for interface compatibility; not used here
            wait_time_after_min_received: accepted for interface compatibility; not used here
            abort_signal: accepted for interface compatibility; not consulted here

        Returns:
            dict mapping client name to reply. If the task-data filters or
            ``task_done_cb`` fail, every target maps to a single error reply.

        Raises:
            ValueError: on invalid targets, a target with no known client, or
                a task timeout that is not > 0.
        """
        engine = fl_ctx.get_engine()
        request = task.data
        # apply task filters
        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, task.data, sticky=False, private=True)
        self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx)

        # first apply privacy-defined filters
        try:
            filter_name = Scope.TASK_DATA_FILTERS_NAME
            task.data = apply_filters(filter_name, request, fl_ctx, self.task_data_filters, task.name, FilterKey.OUT)
        except Exception as e:
            self.log_exception(
                fl_ctx,
                "processing error in task data filter {}; "
                "asked client to try again later".format(secure_format_exception(e)),
            )
            return self._make_error_reply(ReturnCode.TASK_DATA_FILTER_ERROR, targets)

        # Send what the filters produced. A filter may return a new object
        # instead of mutating its input; sending the pre-filter object would
        # then bypass the filter entirely.
        request = task.data

        self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER")
        fl_ctx.set_prop(FLContextKey.TASK_DATA, task.data, sticky=False, private=True)
        self.fire_event(EventType.AFTER_TASK_DATA_FILTER, fl_ctx)

        target_names = get_target_names(targets)
        _, invalid_names = engine.validate_targets(target_names)
        if invalid_names:
            raise ValueError(f"invalid target(s): {invalid_names}")

        # Validate everything before touching the task's bookkeeping so that a
        # failed call does not leave stale ClientTask entries behind.
        if task.timeout <= 0:
            raise ValueError(f"The task timeout must > 0. But got {task.timeout}")

        resolved_clients = []
        for target in targets:
            client = self._get_client(target, engine)
            if client is None:
                raise ValueError(f"unable to find client for target: {target}")
            resolved_clients.append(client)

        # set up ClientTask for each client
        # NOTE: task.before_task_sent_cb is not invoked by this implementation.
        for client in resolved_clients:
            client_task = ClientTask(task=task, client=client)
            task.client_tasks.append(client_task)
            task.last_client_task_map[client_task.id] = client_task

        request.set_header(ReservedKey.TASK_NAME, task.name)
        replies = engine.send_aux_request(
            targets=targets,
            topic=ReservedTopic.DO_TASK,
            request=request,
            timeout=task.timeout,
            fl_ctx=fl_ctx,
            secure=task.secure,
        )

        self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_RESULT_FILTER")
        self.fire_event(EventType.BEFORE_TASK_RESULT_FILTER, fl_ctx)

        for target, reply in replies.items():
            # Only the first client task registered under this name receives the reply.
            for client_task in task.client_tasks:
                if client_task.client.name == target:
                    self._record_reply(task, client_task, target, reply, engine, fl_ctx)
                    break

        # apply task_done_cb
        if task.task_done_cb is not None:
            try:
                task.task_done_cb(task=task, fl_ctx=fl_ctx)
            except Exception as e:
                self.log_exception(
                    fl_ctx, f"processing error in task_done_cb error on task {task.name}: {secure_format_exception(e)}"
                )
                task.completion_status = TaskCompletionStatus.ERROR
                task.exception = e
                return self._make_error_reply(ReturnCode.ERROR, targets)

        return {client_task.client.name: client_task.result for client_task in task.client_tasks}

    def _record_reply(self, task, client_task, target, reply, engine, fl_ctx):
        """Filter one reply, store it on ``client_task`` and run ``result_received_cb``.

        A reply whose return code is not OK, or whose callback fails, is stored
        as a ReturnCode.ERROR reply; a result-filter failure is stored as a
        ReturnCode.TASK_RESULT_FILTER_ERROR reply.
        """
        rc = reply.get_return_code()
        if not (rc and rc == ReturnCode.OK):
            client_task.result = make_reply(ReturnCode.ERROR)
            return

        # apply result filters
        try:
            filter_name = Scope.TASK_RESULT_FILTERS_NAME
            reply = apply_filters(filter_name, reply, fl_ctx, self.task_result_filters, task.name, FilterKey.IN)
        except Exception as e:
            self.log_exception(
                fl_ctx,
                "processing error in task result filter {}; ".format(secure_format_exception(e)),
            )
            client_task.result = make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR)
            return

        # assign the reply to the client task before result_received_cb reads it
        client_task.result = reply

        # Fall back to the client already attached to the client task when the
        # engine lookup finds nothing, so the callback never receives None.
        client = self._get_client(target, engine) or client_task.client
        if self._call_task_cb(task.result_received_cb, client, task, fl_ctx):
            client_task.result = make_reply(ReturnCode.ERROR)

    def _make_error_reply(self, error_type, targets):
        """Build a reply dict that maps every target name to one shared error reply.

        Client objects are keyed by their ``name`` so the keys match those of a
        successful call. ``targets`` of None yields an empty dict.
        """
        error_reply = make_reply(error_type)
        replies = {}
        for target in targets or []:
            name = target.name if isinstance(target, Client) else target
            replies[name] = error_reply
        return replies

    def _get_client(self, client, engine) -> Union[Client, None]:
        """Resolve ``client`` (a Client object or a name) to a Client object.

        Returns the argument itself if it is already a Client, a new Client
        for SiteType.SERVER, the matching entry of ``engine.all_clients``
        otherwise, or None when no entry has that name.
        """
        if isinstance(client, Client):
            return client

        if client == SiteType.SERVER:
            return Client(SiteType.SERVER, None)

        client_obj = None
        for c in engine.all_clients.values():
            if client == c.name:
                client_obj = c
        return client_obj

    def _call_task_cb(self, task_cb, client, task, fl_ctx):
        """Invoke ``task_cb`` for the client's ClientTask while holding ``task.cb_lock``.

        Args:
            task_cb: callback taking ``client_task`` and ``fl_ctx`` keywords, or None
            client: the client whose ClientTask is passed to the callback
            task: the task that owns the ClientTask
            fl_ctx: the FL context

        Returns:
            True if the callback raised (the task is then marked as ERROR and
            the exception is stored on it), otherwise False.
        """
        task_cb_error = False
        with task.cb_lock:
            client_task = self._get_client_task(client, task)

            if task_cb is not None:
                try:
                    task_cb(client_task=client_task, fl_ctx=fl_ctx)
                except Exception as e:
                    self.log_exception(
                        fl_ctx,
                        f"processing error in {task_cb} on task {client_task.task.name} "
                        f"({client_task.id}): {secure_format_exception(e)}",
                    )
                    # this task cannot proceed anymore
                    task.completion_status = TaskCompletionStatus.ERROR
                    task.exception = e
                    task_cb_error = True

            self.logger.debug(f"{task_cb} done on client_task: {client_task}")
            self.logger.debug(f"task completion status is {task.completion_status}")
        return task_cb_error

    def _get_client_task(self, client, task):
        """Return the last ClientTask of ``task`` whose client name matches, or None."""
        client_task = None
        for t in task.client_tasks:
            if t.client.name == client.name:
                client_task = t
        return client_task

    def send(
        self,
        task: Task,
        fl_ctx: FLContext,
        targets: Union[List[Client], List[str], None] = None,
        send_order: SendOrder = SendOrder.SEQUENTIAL,
        task_assignment_timeout: int = 0,
    ):
        """Send the task to a single target; identical to ``send_and_wait``.

        Target validation is performed by ``send_and_wait``.

        Returns:
            dict mapping client name to that client's reply.

        Raises:
            ValueError: if the targets are missing, more than one, or invalid.
        """
        return self.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout)

    def send_and_wait(
        self,
        task: Task,
        fl_ctx: FLContext,
        targets: Union[List[Client], List[str], None] = None,
        send_order: SendOrder = SendOrder.SEQUENTIAL,
        task_assignment_timeout: int = 0,
        abort_signal: Signal = None,
    ):
        """Send the task to exactly one target and return its reply.

        Args:
            task: the task to send
            fl_ctx: the FL context
            targets: a list holding exactly one client (object or name)
            send_order: accepted for interface compatibility; not used here
            task_assignment_timeout: accepted for interface compatibility; not used here
            abort_signal: forwarded to ``broadcast_and_wait``

        Returns:
            dict mapping client name to that client's reply.

        Raises:
            ValueError: if the targets are missing, more than one, or invalid.
        """
        engine = fl_ctx.get_engine()

        self._validate_target(engine, targets)

        replies = {}
        for target in targets:
            replies.update(self.broadcast_and_wait(task, fl_ctx, [target], abort_signal=abort_signal))
        return replies

    def _validate_target(self, engine, targets):
        """Require exactly one target that the engine accepts.

        Raises:
            ValueError: if ``targets`` is None or empty, has more than one
                entry, or contains a name the engine reports as invalid.
        """
        # `not targets` also covers None, which would otherwise fail in len().
        if not targets:
            raise ValueError("Must provide a target to send.")
        if len(targets) != 1:
            raise ValueError("send_and_wait can only send to a single target.")
        target_names = get_target_names(targets)
        _, invalid_names = engine.validate_targets(target_names)
        if invalid_names:
            raise ValueError(f"invalid target(s): {invalid_names}")
Loading
Loading