From eb63d2c23cafcfe24fc3e666e30f3da375e3222c Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 29 Feb 2024 10:47:52 -0800 Subject: [PATCH 1/3] separate communication from controller --- nvflare/apis/controller_spec.py | 13 + nvflare/apis/impl/controller.py | 946 +--------------- nvflare/apis/impl/task_controller.py | 3 + nvflare/apis/impl/wf_comm_client.py | 257 +++++ nvflare/apis/impl/wf_comm_server.py | 1009 +++++++++++++++++ nvflare/apis/responder.py | 14 - nvflare/apis/wf_comm_spec.py | 213 ++++ nvflare/app_common/ccwf/client_ctl.py | 2 +- nvflare/app_common/ccwf/server_ctl.py | 2 +- .../private/fed/server/server_json_config.py | 12 +- nvflare/private/fed/server/server_runner.py | 29 +- 11 files changed, 1560 insertions(+), 940 deletions(-) create mode 100644 nvflare/apis/impl/wf_comm_client.py create mode 100644 nvflare/apis/impl/wf_comm_server.py create mode 100644 nvflare/apis/wf_comm_spec.py diff --git a/nvflare/apis/controller_spec.py b/nvflare/apis/controller_spec.py index 672151d910..2d9f365341 100644 --- a/nvflare/apis/controller_spec.py +++ b/nvflare/apis/controller_spec.py @@ -266,6 +266,19 @@ def start_controller(self, fl_ctx: FLContext): """ pass + @abstractmethod + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + """This is the control logic for the RUN. + + NOTE: this is running in a separate thread, and its life is the duration of the RUN. + + Args: + fl_ctx: the FL context + abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller. + + """ + pass + @abstractmethod def stop_controller(self, fl_ctx: FLContext): """Stops the controller. diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 0d8ca2f88b..1297f35587 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -11,74 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import threading -import time from abc import ABC -from threading import Lock -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, ControllerSpec, SendOrder, Task, TaskCompletionStatus -from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext -from nvflare.apis.job_def import job_from_meta -from nvflare.apis.responder import Responder -from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy +from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.signal import Signal -from nvflare.fuel.utils.config_service import ConfigService -from nvflare.security.logging import secure_format_exception -from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector +from nvflare.apis.wf_comm_spec import WFCommSpec -from .any_relay_manager import AnyRelayTaskManager -from .bcast_manager import BcastForeverTaskManager, BcastTaskManager -from .send_manager import SendTaskManager -from .seq_relay_manager import SequentialRelayTaskManager -from .task_manager import TaskCheckStatus, TaskManager -_TASK_KEY_ENGINE = "___engine" -_TASK_KEY_MANAGER = "___mgr" -_TASK_KEY_DONE = "___done" - -# wait this long since client death report before treating the client as dead -_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period" - -# wait this long since job schedule time before starting to check dead clients -_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" - - -def _check_positive_int(name, value): - if not isinstance(value, int): - raise TypeError("{} must be an instance of int, but got {}.".format(name, type(name))) - if value < 0: - raise ValueError("{} must >= 0.".format(name)) - - -def _check_inputs(task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None]): - if not isinstance(task, Task): - raise TypeError("task must be an instance of Task, but got {}".format(type(task))) - - if not isinstance(fl_ctx, FLContext): - raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) - - if targets is not None: - if not isinstance(targets, list): - raise TypeError("targets must be a list of Client or string, but got {}".format(type(targets))) - - for t in targets: - if not isinstance(t, (Client, str)): - raise TypeError( - "targets must be a list of Client or string, but got element of type {}".format(type(t)) - ) - - -def _get_client_task(target, task: Task): - for ct in task.client_tasks: - if target == ct.client.name: - return ct - return None - - -class Controller(Responder, ControllerSpec, ABC): +class Controller(FLComponent, ControllerSpec, ABC): def __init__(self, task_check_period=0.2): """Manage life cycles of tasks and their destinations. @@ -86,399 +31,18 @@ def __init__(self, task_check_period=0.2): task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2. 
""" super().__init__() - self._engine = None - self._tasks = [] # list of standing tasks - self._client_task_map = {} # client_task_id => client_task - self._all_done = False - self._task_lock = Lock() - self._task_monitor = threading.Thread(target=self._monitor_tasks, args=()) self._task_check_period = task_check_period - self._dead_client_reports = {} # clients that reported the job is dead on it: name => report time - self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads - # make sure _check_tasks, process_task_request, process_submission does not interfere with each other - self._controller_lock = Lock() + self.communicator = None - def initialize_run(self, fl_ctx: FLContext): - """Called by runners to initialize controller with information in fl_ctx. - - .. attention:: - - Controller subclasses must not overwrite this method. - - Args: - fl_ctx (FLContext): FLContext information - """ + def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext): + communicator.task_check_period = self._task_check_period + self.communicator = communicator engine = fl_ctx.get_engine() if not engine: self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx) return self._engine = engine - self.start_controller(fl_ctx) - self._task_monitor.start() - - def _try_again(self) -> Tuple[str, str, Shareable]: - # TODO: how to tell client no shareable available now? - return "", "", None - - def _set_stats(self, fl_ctx: FLContext): - """Called to set stats into InfoCollector. - - Args: - fl_ctx (FLContext): info collector is retrieved from fl_ctx with InfoCollector.CTX_KEY_STATS_COLLECTOR key - """ - collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None) - if collector: - if not isinstance(collector, GroupInfoCollector): - raise TypeError( - "collector must be an instance of GroupInfoCollector, but got {}".format(type(collector)) - ) - collector.set_info( - group_name=self._name, - info={ - "tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks}, - }, - ) - - def handle_event(self, event_type: str, fl_ctx: FLContext): - """Called when events are fired. - - Args: - event_type (str): all event types, including AppEventType and EventType - fl_ctx (FLContext): FLContext information with current event type - """ - if event_type == InfoCollector.EVENT_TYPE_GET_STATS: - self._set_stats(fl_ctx) - - def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: - """Called by runner when a client asks for a task. - - .. note:: - - This is called in a separate thread. 
- - Args: - client (Client): The record of one client requesting tasks - fl_ctx (FLContext): The FLContext associated with this request - - Raises: - TypeError: when client is not an instance of Client - TypeError: when fl_ctx is not an instance of FLContext - TypeError: when any standing task containing an invalid client_task - - Returns: - Tuple[str, str, Shareable]: task_name, an id for the client_task, and the data for this request - """ - with self._controller_lock: - return self._do_process_task_request(client, fl_ctx) - - def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: - if not isinstance(client, Client): - raise TypeError("client must be an instance of Client, but got {}".format(type(client))) - - with self._dead_clients_lock: - self._dead_client_reports.pop(client.name, None) - - if not isinstance(fl_ctx, FLContext): - raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) - - client_task_to_send = None - with self._task_lock: - self.logger.debug("self._tasks: {}".format(self._tasks)) - for task in self._tasks: - if task.completion_status is not None: - # this task is finished (and waiting for the monitor to exit it) - continue - - # do we need to send this task to this client? - # note: the task could be sent to a client multiple times (e.g. in relay) - # we only check the last ClientTask sent to the client - client_task_to_check = task.last_client_task_map.get(client.name, None) - self.logger.debug("client_task_to_check: {}".format(client_task_to_check)) - resend_task = False - - if client_task_to_check is not None: - # this client has been sent the task already - if client_task_to_check.result_received_time is None: - # controller has not received result from client - # something wrong happens when client working on this task, so resend the task - resend_task = True - client_task_to_send = client_task_to_check - fl_ctx.set_prop(FLContextKey.IS_CLIENT_TASK_RESEND, True, sticky=False) - - if not resend_task: - # check with the task manager whether to send - manager = task.props[_TASK_KEY_MANAGER] - if client_task_to_check is None: - client_task_to_check = ClientTask(task=task, client=client) - check_status = manager.check_task_send(client_task_to_check, fl_ctx) - self.logger.debug( - "Checking client task: {}, task.client.name: {}".format( - client_task_to_check, client_task_to_check.client.name - ) - ) - self.logger.debug("Check task send get check_status: {}".format(check_status)) - if check_status == TaskCheckStatus.BLOCK: - # do not send this task, and do not check other tasks - return self._try_again() - elif check_status == TaskCheckStatus.NO_BLOCK: - # do not send this task, but continue to check next task - continue - else: - # creates the client_task to be checked for sending - client_task_to_send = ClientTask(client, task) - break - - # NOTE: move task sending process outside the task lock - # This is to minimize the locking time and to avoid potential deadlock: - # the CB could schedule another task, which requires lock - self.logger.debug("Determining based on client_task_to_send: {}".format(client_task_to_send)) - if client_task_to_send is None: - # no task available for this client - return self._try_again() - - # try to send the task - can_send_task = True - task = client_task_to_send.task - with task.cb_lock: - # Note: must guarantee the after_task_sent_cb is always called - # regardless whether the task is sent successfully. 
- # This is so that the app could clear up things in after_task_sent_cb. - if task.before_task_sent_cb is not None: - try: - task.before_task_sent_cb(client_task=client_task_to_send, fl_ctx=fl_ctx) - except Exception as e: - self.log_exception( - fl_ctx, - "processing error in before_task_sent_cb on task {} ({}): {}".format( - client_task_to_send.task.name, client_task_to_send.id, secure_format_exception(e) - ), - ) - # this task cannot proceed anymore - task.completion_status = TaskCompletionStatus.ERROR - task.exception = e - - self.logger.debug("before_task_sent_cb done on client_task_to_send: {}".format(client_task_to_send)) - self.logger.debug(f"task completion status is {task.completion_status}") - - if task.completion_status is not None: - can_send_task = False - - # remember the task name and data to be sent to the client - # since task.data could be reset by the after_task_sent_cb - task_name = task.name - task_data = task.data - operator = task.operator - - if task.after_task_sent_cb is not None: - try: - task.after_task_sent_cb(client_task=client_task_to_send, fl_ctx=fl_ctx) - except Exception as e: - self.log_exception( - fl_ctx, - "processing error in after_task_sent_cb on task {} ({}): {}".format( - client_task_to_send.task.name, client_task_to_send.id, secure_format_exception(e) - ), - ) - task.completion_status = TaskCompletionStatus.ERROR - task.exception = e - - if task.completion_status is not None: - # NOTE: the CB could cancel the task - can_send_task = False - - if not can_send_task: - return self._try_again() - - self.logger.debug("after_task_sent_cb done on client_task_to_send: {}".format(client_task_to_send)) - - with self._task_lock: - # sent the ClientTask and remember it - now = time.time() - client_task_to_send.task_sent_time = now - client_task_to_send.task_send_count += 1 - - # add task operator to task_data shareable - if operator: - task_data.set_header(key=ReservedHeaderKey.TASK_OPERATOR, value=operator) - - if not resend_task: - task.last_client_task_map[client.name] = client_task_to_send - task.client_tasks.append(client_task_to_send) - self._client_task_map[client_task_to_send.id] = client_task_to_send - - task_data.set_header(ReservedHeaderKey.TASK_ID, client_task_to_send.id) - return task_name, client_task_to_send.id, make_copy(task_data) - - def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None: - """Called to cancel one task as its client_task is causing exception at upper level. - - Args: - task_id (str): an id to the failing client_task - fl_ctx (FLContext): FLContext associated with this client_task - """ - with self._task_lock: - # task_id is the uuid associated with the client_task - client_task = self._client_task_map.get(task_id, None) - self.logger.debug("Handle exception on client_task {} with id {}".format(client_task, task_id)) - - if client_task is None: - # cannot find a standing task on the exception - return - - task = client_task.task - self.cancel_task(task=task, fl_ctx=fl_ctx) - self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name)) - - def handle_dead_job(self, client_name: str, fl_ctx: FLContext): - """Called by the Engine to handle the case that the job on the client is dead. 
- - Args: - client_name: name of the client on which the job is dead - fl_ctx: the FLContext - - """ - # record the report and to be used by the task monitor - with self._dead_clients_lock: - self.log_info(fl_ctx, f"received dead job report from client {client_name}") - if not self._dead_client_reports.get(client_name): - self._dead_client_reports[client_name] = time.time() - - def process_task_check(self, task_id: str, fl_ctx: FLContext): - with self._task_lock: - # task_id is the uuid associated with the client_task - return self._client_task_map.get(task_id, None) - - def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext): - """Called to process a submission from one client. - - .. note:: - - This method is called by a separate thread. - - Args: - client (Client): the client that submitted this task - task_name (str): the task name associated this submission - task_id (str): the id associated with the client_task - result (Shareable): the actual submitted data from the client - fl_ctx (FLContext): the FLContext associated with this submission - - Raises: - TypeError: when client is not an instance of Client - TypeError: when fl_ctx is not an instance of FLContext - TypeError: when result is not an instance of Shareable - ValueError: task_name is not found in the client_task - """ - with self._controller_lock: - self._do_process_submission(client, task_name, task_id, result, fl_ctx) - - def _do_process_submission( - self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext - ): - if not isinstance(client, Client): - raise TypeError("client must be an instance of Client, but got {}".format(type(client))) - - # reset the dead job report! - # note that due to potential race conditions, a client may fail to include the job id in its - # heartbeat (since the job hasn't started at the time of heartbeat report), but then includes - # the job ID later. 
- with self._dead_clients_lock: - self._dead_client_reports.pop(client.name, None) - - if not isinstance(fl_ctx, FLContext): - raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) - if not isinstance(result, Shareable): - raise TypeError("result must be an instance of Shareable, but got {}".format(type(result))) - - with self._task_lock: - # task_id is the uuid associated with the client_task - client_task = self._client_task_map.get(task_id, None) - self.log_debug(fl_ctx, "Get submission from client task={} id={}".format(client_task, task_id)) - - if client_task is None: - # cannot find a standing task for the submission - self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id)) - self.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) - return - - task = client_task.task - with task.cb_lock: - if task.name != task_name: - raise ValueError("client specified task name {} doesn't match {}".format(task_name, task.name)) - - if task.completion_status is not None: - # the task is already finished - drop the result - self.log_info(fl_ctx, "task is already finished - submission dropped") - return - - # do client task CB processing outside the lock - # this is because the CB could schedule another task, which requires the lock - client_task.result = result - - manager = task.props[_TASK_KEY_MANAGER] - manager.check_task_result(result, client_task, fl_ctx) - - if task.result_received_cb is not None: - try: - self.log_debug(fl_ctx, "invoking result_received_cb ...") - task.result_received_cb(client_task=client_task, fl_ctx=fl_ctx) - except Exception as e: - # this task cannot proceed anymore - self.log_exception( - fl_ctx, - "processing error in result_received_cb on task {}({}): {}".format( - task_name, task_id, secure_format_exception(e) - ), - ) - task.completion_status = TaskCompletionStatus.ERROR - task.exception = e - else: - self.log_debug(fl_ctx, "no result_received_cb") - - client_task.result_received_time = time.time() - - def _schedule_task( - self, - task: Task, - fl_ctx: FLContext, - manager: TaskManager, - targets: Union[List[Client], List[str], None], - allow_dup_targets: bool = False, - ): - if task.schedule_time is not None: - # this task was scheduled before - # we do not allow a task object to be reused - self.logger.debug("task.schedule_time: {}".format(task.schedule_time)) - raise ValueError("Task was already used. Please create a new task object.") - - # task.targets = targets - target_names = list() - if targets is None: - for client in self._engine.get_clients(): - target_names.append(client.name) - else: - if not isinstance(targets, list): - raise ValueError("task targets must be a list, but got {}".format(type(targets))) - for t in targets: - if isinstance(t, str): - name = t - elif isinstance(t, Client): - name = t.name - else: - raise ValueError("element in targets must be string or Client type, but got {}".format(type(t))) - - if allow_dup_targets or (name not in target_names): - target_names.append(name) - task.targets = target_names - - task.props[_TASK_KEY_MANAGER] = manager - task.props[_TASK_KEY_ENGINE] = self._engine - task.is_standing = True - task.schedule_time = time.time() - - with self._task_lock: - self._tasks.append(task) - self.log_info(fl_ctx, "scheduled task {}".format(task.name)) def broadcast( self, @@ -488,33 +52,7 @@ def broadcast( min_responses: int = 1, wait_time_after_min_received: int = 0, ): - """Schedule a broadcast task. This is a non-blocking call. 
- - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. - - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. - wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. - Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. - - Raises: - ValueError: min_responses is greater than the length of targets since this condition will make the task, if allowed to be scheduled, never exit. - """ - _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) - _check_positive_int("min_responses", min_responses) - _check_positive_int("wait_time_after_min_received", wait_time_after_min_received) - if targets and min_responses > len(targets): - raise ValueError( - "min_responses ({}) must be less than length of targets ({}).".format(min_responses, len(targets)) - ) - - manager = BcastTaskManager( - task=task, min_responses=min_responses, wait_time_after_min_received=wait_time_after_min_received - ) - self._schedule_task(task=task, fl_ctx=fl_ctx, manager=manager, targets=targets) + self.communicator.broadcast(task, fl_ctx, targets, min_responses, wait_time_after_min_received) def broadcast_and_wait( self, @@ -525,42 +63,12 @@ def broadcast_and_wait( wait_time_after_min_received: int = 0, abort_signal: Optional[Signal] = None, ): - """Schedule a broadcast task. This is a blocking call. - - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. - - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. - wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. - Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. - """ - self.broadcast( - task=task, - fl_ctx=fl_ctx, - targets=targets, - min_responses=min_responses, - wait_time_after_min_received=wait_time_after_min_received, + self.communicator.broadcast_and_wait( + task, fl_ctx, targets, min_responses, wait_time_after_min_received, abort_signal ) - self.wait_for_task(task, abort_signal) def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None): - """Schedule a broadcast task. This is a non-blocking call. - - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. - This broadcast will not end. 
- - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - """ - _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) - manager = BcastForeverTaskManager() - self._schedule_task(task=task, fl_ctx=fl_ctx, manager=manager, targets=targets) + self.communicator.broadcast_forever(task, fl_ctx, targets) def send( self, @@ -570,49 +78,7 @@ def send( send_order: SendOrder = SendOrder.SEQUENTIAL, task_assignment_timeout: int = 0, ): - """Schedule a single task to targets. This is a non-blocking call. - - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. - - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. - task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - - Raises: - ValueError: when task_assignment_timeout is greater than task's timeout. - TypeError: send_order is not defined in SendOrder - ValueError: targets is None or an empty list - """ - _check_inputs( - task=task, - fl_ctx=fl_ctx, - targets=targets, - ) - _check_positive_int("task_assignment_timeout", task_assignment_timeout) - if task.timeout and task_assignment_timeout and task_assignment_timeout > task.timeout: - raise ValueError( - "task_assignment_timeout ({}) needs to be less than or equal to task.timeout ({}).".format( - task_assignment_timeout, task.timeout - ) - ) - if not isinstance(send_order, SendOrder): - raise TypeError("send_order must be in Enum SendOrder, but got {}".format(type(send_order))) - - # targets must be provided - if targets is None or len(targets) == 0: - raise ValueError("Targets must be provided for send.") - - manager = SendTaskManager(task, send_order, task_assignment_timeout) - self._schedule_task( - task=task, - fl_ctx=fl_ctx, - manager=manager, - targets=targets, - ) + self.communicator.send(task, fl_ctx, targets, send_order, task_assignment_timeout) def send_and_wait( self, @@ -623,84 +89,7 @@ def send_and_wait( task_assignment_timeout: int = 0, abort_signal: Signal = None, ): - """Schedule a single task to targets. This is a blocking call. - - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. - - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. - task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. 
- abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. - - """ - self.send( - task=task, - fl_ctx=fl_ctx, - targets=targets, - send_order=send_order, - task_assignment_timeout=task_assignment_timeout, - ) - self.wait_for_task(task, abort_signal) - - def get_num_standing_tasks(self) -> int: - """Get the number of tasks that are currently standing. - - Returns: - int: length of the list of standing tasks - """ - return len(self._tasks) - - def cancel_task( - self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None - ): - """Cancel the specified task. - - Change the task completion_status, which will inform task monitor to clean up this task - - note:: - - We only mark the task as completed and leave it to the task monitor to clean up. This is to avoid potential deadlock of task_lock. - - Args: - task (Task): the task to be cancelled - completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. - fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. - """ - task.completion_status = completion_status - - def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): - """Cancel all standing tasks in this controller. - - Args: - completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. - fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. - """ - with self._task_lock: - for t in self._tasks: - t.completion_status = completion_status - - def finalize_run(self, fl_ctx: FLContext): - """Do cleanup of the coordinator implementation. - - .. attention:: - - Subclass controllers should not overwrite finalize_run. - - Args: - fl_ctx (FLContext): FLContext associated with this action - """ - self.cancel_all_tasks() # unconditionally cancel all tasks - self._all_done = True - try: - if self._task_monitor.is_alive(): - self._task_monitor.join() - except RuntimeError: - self.log_debug(fl_ctx, "unable to join monitor thread (not started?)") - self.stop_controller(fl_ctx) + self.communicator.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout, abort_signal) def relay( self, @@ -712,72 +101,8 @@ def relay( task_result_timeout: int = 0, dynamic_targets: bool = True, ): - """Schedule a single task to targets in one-after-another style. This is a non-blocking call. - - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. - - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. - SEQUENTIAL means the order in targets is enforced. - ANY means any clients that are inside the targets and haven't received the task are eligible. Defaults to SendOrder.SEQUENTIAL. - task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. 
- dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. - - Raises: - ValueError: when task_assignment_timeout is greater than task's timeout - ValueError: when task_result_timeout is greater than task's timeout - TypeError: send_order is not defined in SendOrder - TypeError: when dynamic_targets is not a boolean variable - ValueError: targets is None or an empty list but dynamic_targets is False - """ - _check_inputs( - task=task, - fl_ctx=fl_ctx, - targets=targets, - ) - _check_positive_int("task_assignment_timeout", task_assignment_timeout) - _check_positive_int("task_result_timeout", task_result_timeout) - if task.timeout and task_assignment_timeout and task_assignment_timeout > task.timeout: - raise ValueError( - "task_assignment_timeout ({}) needs to be less than or equal to task.timeout ({}).".format( - task_assignment_timeout, task.timeout - ) - ) - if task.timeout and task_result_timeout and task_result_timeout > task.timeout: - raise ValueError( - "task_result_timeout ({}) needs to be less than or equal to task.timeout ({}).".format( - task_result_timeout, task.timeout - ) - ) - if not isinstance(send_order, SendOrder): - raise TypeError("send_order must be in Enum SendOrder, but got {}".format(type(send_order))) - if not isinstance(dynamic_targets, bool): - raise TypeError("dynamic_targets must be an instance of bool, but got {}".format(type(dynamic_targets))) - if targets is None and dynamic_targets is False: - raise ValueError("Need to provide targets when dynamic_targets is set to False.") - - if send_order == SendOrder.SEQUENTIAL: - manager = SequentialRelayTaskManager( - task=task, - task_assignment_timeout=task_assignment_timeout, - task_result_timeout=task_result_timeout, - dynamic_targets=dynamic_targets, - ) - else: - manager = AnyRelayTaskManager( - task=task, task_result_timeout=task_result_timeout, dynamic_targets=dynamic_targets - ) - - self._schedule_task( - task=task, - fl_ctx=fl_ctx, - manager=manager, - targets=targets, - allow_dup_targets=True, + self.communicator.relay( + task, fl_ctx, targets, send_order, task_assignment_timeout, task_result_timeout, dynamic_targets ) def relay_and_wait( @@ -791,221 +116,30 @@ def relay_and_wait( dynamic_targets: bool = True, abort_signal: Optional[Signal] = None, ): - """Schedule a single task to targets in one-after-another style. This is a blocking call. - - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. - - Args: - task (Task): the task to be scheduled - fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. - task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. - dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. 
Defaults to None. - """ - self.relay( - task=task, - fl_ctx=fl_ctx, - targets=targets, - send_order=send_order, - task_assignment_timeout=task_assignment_timeout, - task_result_timeout=task_result_timeout, - dynamic_targets=dynamic_targets, + self.communicator.relay_and_wait( + task, + fl_ctx, + targets, + send_order, + task_assignment_timeout, + task_result_timeout, + dynamic_targets, + abort_signal, ) - self.wait_for_task(task, abort_signal) - - def _monitor_tasks(self): - while not self._all_done: - should_abort_job = self._job_policy_violated() - if not should_abort_job: - self._check_tasks() - else: - with self._engine.new_context() as fl_ctx: - self.system_panic("Aborting job due to deployment policy violation", fl_ctx) - return - time.sleep(self._task_check_period) - - def _check_tasks(self): - with self._controller_lock: - self._do_check_tasks() - - def _do_check_tasks(self): - exit_tasks = [] - with self._task_lock: - for task in self._tasks: - if task.completion_status is not None: - exit_tasks.append(task) - continue - - # check the task-specific exit condition - manager = task.props[_TASK_KEY_MANAGER] - if manager is not None: - if not isinstance(manager, TaskManager): - raise TypeError( - "manager in task must be an instance of TaskManager, but got {}".format(manager) - ) - should_exit, exit_status = manager.check_task_exit(task) - self.logger.debug("should_exit: {}, exit_status: {}".format(should_exit, exit_status)) - if should_exit: - task.completion_status = exit_status - exit_tasks.append(task) - continue - - # check if task timeout - if task.timeout and time.time() - task.schedule_time >= task.timeout: - task.completion_status = TaskCompletionStatus.TIMEOUT - exit_tasks.append(task) - continue - - # check whether clients that the task is waiting are all dead - dead_clients = self._get_task_dead_clients(task) - if dead_clients: - self.logger.info(f"client {dead_clients} is dead - set task {task.name} to TIMEOUT") - task.completion_status = TaskCompletionStatus.CLIENT_DEAD - exit_tasks.append(task) - continue - - for exit_task in exit_tasks: - exit_task.is_standing = False - self.logger.debug( - "Removing task={}, completion_status={}".format(exit_task, exit_task.completion_status) - ) - self._tasks.remove(exit_task) - for client_task in exit_task.client_tasks: - self.logger.debug("Removing client_task with id={}".format(client_task.id)) - self._client_task_map.pop(client_task.id) - - # do the task exit processing outside the lock to minimize the locking time - # and to avoid potential deadlock since the CB could schedule another task - if len(exit_tasks) <= 0: - return - - with self._engine.new_context() as fl_ctx: - for exit_task in exit_tasks: - with exit_task.cb_lock: - self.log_info( - fl_ctx, "task {} exit with status {}".format(exit_task.name, exit_task.completion_status) - ) - if exit_task.task_done_cb is not None: - try: - exit_task.task_done_cb(task=exit_task, fl_ctx=fl_ctx) - except Exception as e: - self.log_exception( - fl_ctx, - "processing error in task_done_cb error on task {}: {}".format( - exit_task.name, secure_format_exception(e) - ), - ) - exit_task.completion_status = TaskCompletionStatus.ERROR - exit_task.exception = e - - def _get_task_dead_clients(self, task: Task): - """ - See whether the task is only waiting for response from a dead client - """ - now = time.time() - lead_time = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME, default=30.0) - if now - task.schedule_time < lead_time: - # due to potential race conditions, 
we'll wait for at least 1 minute after the task - # is started before checking dead clients. - return None - - dead_clients = [] - with self._dead_clients_lock: - for target in task.targets: - ct = _get_client_task(target, task) - if ct is not None and ct.result_received_time: - # response has been received from this client - continue - - # either we have not sent the task to this client or we have not received response - # is the client already dead? - if self._client_still_alive(target): - # this client is still alive - # we let the task continue its course since we still have live clients - return None - else: - # this client is dead - remember it - dead_clients.append(target) - - return dead_clients - - @staticmethod - def _process_finished_task(task, func): - def wrap(*args, **kwargs): - if func: - func(*args, **kwargs) - task.props[_TASK_KEY_DONE] = True - - return wrap - - def wait_for_task(self, task: Task, abort_signal: Signal): - task.props[_TASK_KEY_DONE] = False - task.task_done_cb = self._process_finished_task(task=task, func=task.task_done_cb) - while True: - if task.completion_status is not None: - break - - if abort_signal and abort_signal.triggered: - self.cancel_task(task, fl_ctx=None, completion_status=TaskCompletionStatus.ABORTED) - break - - task_done = task.props[_TASK_KEY_DONE] - if task_done: - break - time.sleep(self._task_check_period) - - def _job_policy_violated(self): - if not self._engine: - return False - - with self._engine.new_context() as fl_ctx: - clients = self._engine.get_clients() - with self._dead_clients_lock: - alive_clients = [] - dead_clients = [] - - for client in clients: - if self._client_still_alive(client.name): - alive_clients.append(client.name) - else: - dead_clients.append(client.name) - - if not dead_clients: - return False - - if not alive_clients: - self.log_error(fl_ctx, f"All clients are dead: {dead_clients}") - return True - - job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) - job = job_from_meta(job_meta) - if len(alive_clients) < job.min_sites: - self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}") - return True - - # check required clients: - if dead_clients and job.required_sites: - dead_required_clients = [c for c in dead_clients if c in job.required_sites] - if dead_required_clients: - self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}") - return True - return False - - def _client_still_alive(self, client_name): - now = time.time() - report_time = self._dead_client_reports.get(client_name, None) - grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=30.0) + def get_num_standing_tasks(self) -> int: + if not isinstance(self.communicator, WFCommServer): + raise NotImplementedError + return self.communicator.get_num_standing_tasks() - if not report_time: - # this client is still alive - return True - elif now - report_time < grace_period: - # this report is still fresh - consider the client to be still alive - return True + def cancel_task( + self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None + ): + if not isinstance(self.communicator, WFCommServer): + raise NotImplementedError + self.communicator.cancel_task(task, completion_status, fl_ctx) - return False + def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): + if not isinstance(self.communicator, WFCommServer): + raise NotImplementedError + 
self.communicator.cancel_all_tasks(completion_status, fl_ctx) diff --git a/nvflare/apis/impl/task_controller.py b/nvflare/apis/impl/task_controller.py index 050b35817d..d18cd5a96a 100644 --- a/nvflare/apis/impl/task_controller.py +++ b/nvflare/apis/impl/task_controller.py @@ -51,6 +51,9 @@ def start_controller(self, fl_ctx: FLContext): if not self.task_result_filters: self.task_result_filters = {} + def control_flow(self, fl_ctx: FLContext): + pass + def stop_controller(self, fl_ctx: FLContext): pass diff --git a/nvflare/apis/impl/wf_comm_client.py b/nvflare/apis/impl/wf_comm_client.py new file mode 100644 index 0000000000..a8753f5b64 --- /dev/null +++ b/nvflare/apis/impl/wf_comm_client.py @@ -0,0 +1,257 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_constant import FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode, SiteType +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import make_reply +from nvflare.apis.signal import Signal +from nvflare.apis.utils.task_utils import apply_filters +from nvflare.apis.wf_comm_spec import WFCommSpec +from nvflare.private.fed.utils.fed_utils import get_target_names +from nvflare.private.privacy_manager import Scope +from nvflare.security.logging import secure_format_exception + + +class WFCommClient(FLComponent, WFCommSpec): + def __init__( + self, + ) -> None: + super().__init__() + self.task_data_filters = {} + self.task_result_filters = {} + + def broadcast( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + min_responses: int = 0, + wait_time_after_min_received: int = 0, + ): + return self.broadcast_and_wait(task, fl_ctx, targets, min_responses, wait_time_after_min_received) + + def broadcast_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + min_responses: int = 0, + wait_time_after_min_received: int = 0, + abort_signal: Signal = None, + ): + engine = fl_ctx.get_engine() + request = task.data + # apply task filters + self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER") + fl_ctx.set_prop(FLContextKey.TASK_DATA, task.data, sticky=False, private=True) + self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx) + + # # first apply privacy-defined filters + try: + filter_name = Scope.TASK_DATA_FILTERS_NAME + task.data = apply_filters(filter_name, request, fl_ctx, self.task_data_filters, task.name, FilterKey.OUT) + except Exception as e: + self.log_exception( + fl_ctx, + "processing error in task data filter {}; " + "asked client to try again later".format(secure_format_exception(e)), + ) + replies = 
self._make_error_reply(ReturnCode.TASK_DATA_FILTER_ERROR, targets) + return replies + + self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER") + fl_ctx.set_prop(FLContextKey.TASK_DATA, task.data, sticky=False, private=True) + self.fire_event(EventType.AFTER_TASK_DATA_FILTER, fl_ctx) + + target_names = get_target_names(targets) + _, invalid_names = engine.validate_targets(target_names) + if invalid_names: + raise ValueError(f"invalid target(s): {invalid_names}") + + # set up ClientTask for each client + for target in targets: + client: Client = self._get_client(target, engine) + client_task = ClientTask(task=task, client=client) + task.client_tasks.append(client_task) + task.last_client_task_map[client_task.id] = client_task + + # task_cb_error = self._call_task_cb(task.before_task_sent_cb, client, task, fl_ctx) + # if task_cb_error: + # return self._make_error_reply(ReturnCode.ERROR, targets) + + if task.timeout <= 0: + raise ValueError(f"The task timeout must > 0. But got {task.timeout}") + + request.set_header(ReservedKey.TASK_NAME, task.name) + replies = engine.send_aux_request( + targets=targets, + topic=ReservedTopic.DO_TASK, + request=request, + timeout=task.timeout, + fl_ctx=fl_ctx, + secure=task.secure, + ) + + self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_RESULT_FILTER") + self.fire_event(EventType.BEFORE_TASK_RESULT_FILTER, fl_ctx) + + for target, reply in replies.items(): + # get the client task for the target + for client_task in task.client_tasks: + if client_task.client.name == target: + rc = reply.get_return_code() + if rc and rc == ReturnCode.OK: + # apply result filters + try: + filter_name = Scope.TASK_RESULT_FILTERS_NAME + reply = apply_filters( + filter_name, reply, fl_ctx, self.task_result_filters, task.name, FilterKey.IN + ) + except Exception as e: + self.log_exception( + fl_ctx, + "processing error in task result filter {}; ".format(secure_format_exception(e)), + ) + error_reply = make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR) + client_task.result = error_reply + break + + # assign replies to client task, prepare for the result_received_cb + client_task.result = reply + + client: Client = self._get_client(target, engine) + task_cb_error = self._call_task_cb(task.result_received_cb, client, task, fl_ctx) + if task_cb_error: + client_task.result = make_reply(ReturnCode.ERROR) + break + else: + client_task.result = make_reply(ReturnCode.ERROR) + + break + + # apply task_done_cb + if task.task_done_cb is not None: + try: + task.task_done_cb(task=task, fl_ctx=fl_ctx) + except Exception as e: + self.log_exception( + fl_ctx, f"processing error in task_done_cb error on task {task.name}: {secure_format_exception(e)}" + ), + task.completion_status = TaskCompletionStatus.ERROR + task.exception = e + return self._make_error_reply(ReturnCode.ERROR, targets) + + replies = {} + for client_task in task.client_tasks: + replies[client_task.client.name] = client_task.result + return replies + + def _make_error_reply(self, error_type, targets): + error_reply = make_reply(error_type) + replies = {} + for target in targets: + replies[target] = error_reply + return replies + + def _get_client(self, client, engine) -> Client: + if isinstance(client, Client): + return client + + if client == SiteType.SERVER: + return Client(SiteType.SERVER, None) + + client_obj = None + for _, c in engine.all_clients.items(): + if client == c.name: + client_obj = c + return client_obj + + def _call_task_cb(self, task_cb, client, task, fl_ctx): + task_cb_error = False + with 
task.cb_lock: + client_task = self._get_client_task(client, task) + + if task_cb is not None: + try: + task_cb(client_task=client_task, fl_ctx=fl_ctx) + except Exception as e: + self.log_exception( + fl_ctx, + f"processing error in {task_cb} on task {client_task.task.name} " + f"({client_task.id}): {secure_format_exception(e)}", + ) + # this task cannot proceed anymore + task.completion_status = TaskCompletionStatus.ERROR + task.exception = e + task_cb_error = True + + self.logger.debug(f"{task_cb} done on client_task: {client_task}") + self.logger.debug(f"task completion status is {task.completion_status}") + return task_cb_error + + def _get_client_task(self, client, task): + client_task = None + for t in task.client_tasks: + if t.client.name == client.name: + client_task = t + return client_task + + def send( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + ): + engine = fl_ctx.get_engine() + + self._validate_target(engine, targets) + + return self.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout) + + def send_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + abort_signal: Signal = None, + ): + engine = fl_ctx.get_engine() + + self._validate_target(engine, targets) + + replies = {} + for target in targets: + reply = self.broadcast_and_wait(task, fl_ctx, [target], abort_signal=abort_signal) + replies.update(reply) + return replies + + def _validate_target(self, engine, targets): + if len(targets) == 0: + raise ValueError("Must provide a target to send.") + if len(targets) != 1: + raise ValueError("send_and_wait can only send to a single target.") + target_names = get_target_names(targets) + _, invalid_names = engine.validate_targets(target_names) + if invalid_names: + raise ValueError(f"invalid target(s): {invalid_names}") diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py new file mode 100644 index 0000000000..b15f555c1a --- /dev/null +++ b/nvflare/apis/impl/wf_comm_server.py @@ -0,0 +1,1009 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import threading +import time +from threading import Lock +from typing import List, Optional, Tuple, Union + +from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import job_from_meta +from nvflare.apis.responder import Responder +from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy +from nvflare.apis.signal import Signal +from nvflare.apis.wf_comm_spec import WFCommSpec +from nvflare.fuel.utils.config_service import ConfigService +from nvflare.security.logging import secure_format_exception +from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector + +from .any_relay_manager import AnyRelayTaskManager +from .bcast_manager import BcastForeverTaskManager, BcastTaskManager +from .send_manager import SendTaskManager +from .seq_relay_manager import SequentialRelayTaskManager +from .task_manager import TaskCheckStatus, TaskManager + +_TASK_KEY_ENGINE = "___engine" +_TASK_KEY_MANAGER = "___mgr" +_TASK_KEY_DONE = "___done" + +# wait this long since client death report before treating the client as dead +_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period" + +# wait this long since job schedule time before starting to check dead clients +_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" + + +def _check_positive_int(name, value): + if not isinstance(value, int): + raise TypeError("{} must be an instance of int, but got {}.".format(name, type(name))) + if value < 0: + raise ValueError("{} must >= 0.".format(name)) + + +def _check_inputs(task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None]): + if not isinstance(task, Task): + raise TypeError("task must be an instance of Task, but got {}".format(type(task))) + + if not isinstance(fl_ctx, FLContext): + raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) + + if targets is not None: + if not isinstance(targets, list): + raise TypeError("targets must be a list of Client or string, but got {}".format(type(targets))) + + for t in targets: + if not isinstance(t, (Client, str)): + raise TypeError( + "targets must be a list of Client or string, but got element of type {}".format(type(t)) + ) + + +def _get_client_task(target, task: Task): + for ct in task.client_tasks: + if target == ct.client.name: + return ct + return None + + +class WFCommServer(Responder, WFCommSpec): + def __init__(self, task_check_period=0.2): + """Manage life cycles of tasks and their destinations. + + Args: + task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2. 
+ """ + super().__init__() + self._engine = None + self._tasks = [] # list of standing tasks + self._client_task_map = {} # client_task_id => client_task + self._all_done = False + self._task_lock = Lock() + self._task_monitor = threading.Thread(target=self._monitor_tasks, args=()) + self._task_check_period = task_check_period + self._dead_client_reports = {} # clients that reported the job is dead on it: name => report time + self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads + # make sure _check_tasks, process_task_request, process_submission does not interfere with each other + self._controller_lock = Lock() + + def initialize_run(self, fl_ctx: FLContext): + """Called by runners to initialize controller with information in fl_ctx. + + .. attention:: + + Controller subclasses must not overwrite this method. + + Args: + fl_ctx (FLContext): FLContext information + """ + engine = fl_ctx.get_engine() + if not engine: + self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx) + return + + self._engine = engine + self._task_monitor.start() + + def _try_again(self) -> Tuple[str, str, Shareable]: + # TODO: how to tell client no shareable available now? + return "", "", None + + def _set_stats(self, fl_ctx: FLContext): + """Called to set stats into InfoCollector. + + Args: + fl_ctx (FLContext): info collector is retrieved from fl_ctx with InfoCollector.CTX_KEY_STATS_COLLECTOR key + """ + collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None) + if collector: + if not isinstance(collector, GroupInfoCollector): + raise TypeError( + "collector must be an instance of GroupInfoCollector, but got {}".format(type(collector)) + ) + collector.set_info( + group_name=self._name, + info={ + "tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks}, + }, + ) + + def handle_event(self, event_type: str, fl_ctx: FLContext): + """Called when events are fired. + + Args: + event_type (str): all event types, including AppEventType and EventType + fl_ctx (FLContext): FLContext information with current event type + """ + if event_type == InfoCollector.EVENT_TYPE_GET_STATS: + self._set_stats(fl_ctx) + + def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: + """Called by runner when a client asks for a task. + + .. note:: + + This is called in a separate thread. 
+ + Args: + client (Client): The record of one client requesting tasks + fl_ctx (FLContext): The FLContext associated with this request + + Raises: + TypeError: when client is not an instance of Client + TypeError: when fl_ctx is not an instance of FLContext + TypeError: when any standing task containing an invalid client_task + + Returns: + Tuple[str, str, Shareable]: task_name, an id for the client_task, and the data for this request + """ + with self._controller_lock: + return self._do_process_task_request(client, fl_ctx) + + def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: + if not isinstance(client, Client): + raise TypeError("client must be an instance of Client, but got {}".format(type(client))) + + with self._dead_clients_lock: + self._dead_client_reports.pop(client.name, None) + + if not isinstance(fl_ctx, FLContext): + raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) + + client_task_to_send = None + with self._task_lock: + self.logger.debug("self._tasks: {}".format(self._tasks)) + for task in self._tasks: + if task.completion_status is not None: + # this task is finished (and waiting for the monitor to exit it) + continue + + # do we need to send this task to this client? + # note: the task could be sent to a client multiple times (e.g. in relay) + # we only check the last ClientTask sent to the client + client_task_to_check = task.last_client_task_map.get(client.name, None) + self.logger.debug("client_task_to_check: {}".format(client_task_to_check)) + resend_task = False + + if client_task_to_check is not None: + # this client has been sent the task already + if client_task_to_check.result_received_time is None: + # controller has not received result from client + # something wrong happens when client working on this task, so resend the task + resend_task = True + client_task_to_send = client_task_to_check + fl_ctx.set_prop(FLContextKey.IS_CLIENT_TASK_RESEND, True, sticky=False) + + if not resend_task: + # check with the task manager whether to send + manager = task.props[_TASK_KEY_MANAGER] + if client_task_to_check is None: + client_task_to_check = ClientTask(task=task, client=client) + check_status = manager.check_task_send(client_task_to_check, fl_ctx) + self.logger.debug( + "Checking client task: {}, task.client.name: {}".format( + client_task_to_check, client_task_to_check.client.name + ) + ) + self.logger.debug("Check task send get check_status: {}".format(check_status)) + if check_status == TaskCheckStatus.BLOCK: + # do not send this task, and do not check other tasks + return self._try_again() + elif check_status == TaskCheckStatus.NO_BLOCK: + # do not send this task, but continue to check next task + continue + else: + # creates the client_task to be checked for sending + client_task_to_send = ClientTask(client, task) + break + + # NOTE: move task sending process outside the task lock + # This is to minimize the locking time and to avoid potential deadlock: + # the CB could schedule another task, which requires lock + self.logger.debug("Determining based on client_task_to_send: {}".format(client_task_to_send)) + if client_task_to_send is None: + # no task available for this client + return self._try_again() + + # try to send the task + can_send_task = True + task = client_task_to_send.task + with task.cb_lock: + # Note: must guarantee the after_task_sent_cb is always called + # regardless whether the task is sent successfully. 
+ # This is so that the app could clear up things in after_task_sent_cb. + if task.before_task_sent_cb is not None: + try: + task.before_task_sent_cb(client_task=client_task_to_send, fl_ctx=fl_ctx) + except Exception as e: + self.log_exception( + fl_ctx, + "processing error in before_task_sent_cb on task {} ({}): {}".format( + client_task_to_send.task.name, client_task_to_send.id, secure_format_exception(e) + ), + ) + # this task cannot proceed anymore + task.completion_status = TaskCompletionStatus.ERROR + task.exception = e + + self.logger.debug("before_task_sent_cb done on client_task_to_send: {}".format(client_task_to_send)) + self.logger.debug(f"task completion status is {task.completion_status}") + + if task.completion_status is not None: + can_send_task = False + + # remember the task name and data to be sent to the client + # since task.data could be reset by the after_task_sent_cb + task_name = task.name + task_data = task.data + operator = task.operator + + if task.after_task_sent_cb is not None: + try: + task.after_task_sent_cb(client_task=client_task_to_send, fl_ctx=fl_ctx) + except Exception as e: + self.log_exception( + fl_ctx, + "processing error in after_task_sent_cb on task {} ({}): {}".format( + client_task_to_send.task.name, client_task_to_send.id, secure_format_exception(e) + ), + ) + task.completion_status = TaskCompletionStatus.ERROR + task.exception = e + + if task.completion_status is not None: + # NOTE: the CB could cancel the task + can_send_task = False + + if not can_send_task: + return self._try_again() + + self.logger.debug("after_task_sent_cb done on client_task_to_send: {}".format(client_task_to_send)) + + with self._task_lock: + # sent the ClientTask and remember it + now = time.time() + client_task_to_send.task_sent_time = now + client_task_to_send.task_send_count += 1 + + # add task operator to task_data shareable + if operator: + task_data.set_header(key=ReservedHeaderKey.TASK_OPERATOR, value=operator) + + if not resend_task: + task.last_client_task_map[client.name] = client_task_to_send + task.client_tasks.append(client_task_to_send) + self._client_task_map[client_task_to_send.id] = client_task_to_send + + task_data.set_header(ReservedHeaderKey.TASK_ID, client_task_to_send.id) + return task_name, client_task_to_send.id, make_copy(task_data) + + def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None: + """Called to cancel one task as its client_task is causing exception at upper level. + + Args: + task_id (str): an id to the failing client_task + fl_ctx (FLContext): FLContext associated with this client_task + """ + with self._task_lock: + # task_id is the uuid associated with the client_task + client_task = self._client_task_map.get(task_id, None) + self.logger.debug("Handle exception on client_task {} with id {}".format(client_task, task_id)) + + if client_task is None: + # cannot find a standing task on the exception + return + + task = client_task.task + self.cancel_task(task=task, fl_ctx=fl_ctx) + self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name)) + + def handle_dead_job(self, client_name: str, fl_ctx: FLContext): + """Called by the Engine to handle the case that the job on the client is dead. 
+ + Args: + client_name: name of the client on which the job is dead + fl_ctx: the FLContext + + """ + # record the report and to be used by the task monitor + with self._dead_clients_lock: + self.log_info(fl_ctx, f"received dead job report from client {client_name}") + if not self._dead_client_reports.get(client_name): + self._dead_client_reports[client_name] = time.time() + + def process_task_check(self, task_id: str, fl_ctx: FLContext): + with self._task_lock: + # task_id is the uuid associated with the client_task + return self._client_task_map.get(task_id, None) + + def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext): + """Called to process a submission from one client. + + .. note:: + + This method is called by a separate thread. + + Args: + client (Client): the client that submitted this task + task_name (str): the task name associated this submission + task_id (str): the id associated with the client_task + result (Shareable): the actual submitted data from the client + fl_ctx (FLContext): the FLContext associated with this submission + + Raises: + TypeError: when client is not an instance of Client + TypeError: when fl_ctx is not an instance of FLContext + TypeError: when result is not an instance of Shareable + ValueError: task_name is not found in the client_task + """ + with self._controller_lock: + self._do_process_submission(client, task_name, task_id, result, fl_ctx) + + def _do_process_submission( + self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext + ): + if not isinstance(client, Client): + raise TypeError("client must be an instance of Client, but got {}".format(type(client))) + + # reset the dead job report! + # note that due to potential race conditions, a client may fail to include the job id in its + # heartbeat (since the job hasn't started at the time of heartbeat report), but then includes + # the job ID later. 
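        # [illustrative annotation, not a patch line] dead-client bookkeeping is one-directional:
        # handle_dead_job() records the report time for a client, any later task request or
        # submission from that client clears it, and the task monitor applies the configurable
        # grace period before actually treating the client as dead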
+ with self._dead_clients_lock: + self._dead_client_reports.pop(client.name, None) + + if not isinstance(fl_ctx, FLContext): + raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) + if not isinstance(result, Shareable): + raise TypeError("result must be an instance of Shareable, but got {}".format(type(result))) + + with self._task_lock: + # task_id is the uuid associated with the client_task + client_task = self._client_task_map.get(task_id, None) + self.log_debug(fl_ctx, "Get submission from client task={} id={}".format(client_task, task_id)) + + if client_task is None: + # cannot find a standing task for the submission + self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id)) + self.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) + return + + task = client_task.task + with task.cb_lock: + if task.name != task_name: + raise ValueError("client specified task name {} doesn't match {}".format(task_name, task.name)) + + if task.completion_status is not None: + # the task is already finished - drop the result + self.log_info(fl_ctx, "task is already finished - submission dropped") + return + + # do client task CB processing outside the lock + # this is because the CB could schedule another task, which requires the lock + client_task.result = result + + manager = task.props[_TASK_KEY_MANAGER] + manager.check_task_result(result, client_task, fl_ctx) + + if task.result_received_cb is not None: + try: + self.log_debug(fl_ctx, "invoking result_received_cb ...") + task.result_received_cb(client_task=client_task, fl_ctx=fl_ctx) + except Exception as e: + # this task cannot proceed anymore + self.log_exception( + fl_ctx, + "processing error in result_received_cb on task {}({}): {}".format( + task_name, task_id, secure_format_exception(e) + ), + ) + task.completion_status = TaskCompletionStatus.ERROR + task.exception = e + else: + self.log_debug(fl_ctx, "no result_received_cb") + + client_task.result_received_time = time.time() + + def _schedule_task( + self, + task: Task, + fl_ctx: FLContext, + manager: TaskManager, + targets: Union[List[Client], List[str], None], + allow_dup_targets: bool = False, + ): + if task.schedule_time is not None: + # this task was scheduled before + # we do not allow a task object to be reused + self.logger.debug("task.schedule_time: {}".format(task.schedule_time)) + raise ValueError("Task was already used. 
Please create a new task object.") + + # task.targets = targets + target_names = list() + if targets is None: + for client in self._engine.get_clients(): + target_names.append(client.name) + else: + if not isinstance(targets, list): + raise ValueError("task targets must be a list, but got {}".format(type(targets))) + for t in targets: + if isinstance(t, str): + name = t + elif isinstance(t, Client): + name = t.name + else: + raise ValueError("element in targets must be string or Client type, but got {}".format(type(t))) + + if allow_dup_targets or (name not in target_names): + target_names.append(name) + task.targets = target_names + + task.props[_TASK_KEY_MANAGER] = manager + task.props[_TASK_KEY_ENGINE] = self._engine + task.is_standing = True + task.schedule_time = time.time() + + with self._task_lock: + self._tasks.append(task) + self.log_info(fl_ctx, "scheduled task {}".format(task.name)) + + def broadcast( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + min_responses: int = 1, + wait_time_after_min_received: int = 0, + ): + """Schedule a broadcast task. This is a non-blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. + wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. + Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. + + Raises: + ValueError: min_responses is greater than the length of targets since this condition will make the task, if allowed to be scheduled, never exit. + """ + _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) + _check_positive_int("min_responses", min_responses) + _check_positive_int("wait_time_after_min_received", wait_time_after_min_received) + if targets and min_responses > len(targets): + raise ValueError( + "min_responses ({}) must be less than length of targets ({}).".format(min_responses, len(targets)) + ) + + manager = BcastTaskManager( + task=task, min_responses=min_responses, wait_time_after_min_received=wait_time_after_min_received + ) + self._schedule_task(task=task, fl_ctx=fl_ctx, manager=manager, targets=targets) + + def broadcast_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + min_responses: int = 1, + wait_time_after_min_received: int = 0, + abort_signal: Optional[Signal] = None, + ): + """Schedule a broadcast task. This is a blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. 
+ wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. + Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + """ + self.broadcast( + task=task, + fl_ctx=fl_ctx, + targets=targets, + min_responses=min_responses, + wait_time_after_min_received=wait_time_after_min_received, + ) + self.wait_for_task(task, abort_signal) + + def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None): + """Schedule a broadcast task. This is a non-blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + This broadcast will not end. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + """ + _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) + manager = BcastForeverTaskManager() + self._schedule_task(task=task, fl_ctx=fl_ctx, manager=manager, targets=targets) + + def send( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + ): + """Schedule a single task to targets. This is a non-blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means + clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. + + Raises: + ValueError: when task_assignment_timeout is greater than task's timeout. 
+ TypeError: send_order is not defined in SendOrder + ValueError: targets is None or an empty list + """ + _check_inputs( + task=task, + fl_ctx=fl_ctx, + targets=targets, + ) + _check_positive_int("task_assignment_timeout", task_assignment_timeout) + if task.timeout and task_assignment_timeout and task_assignment_timeout > task.timeout: + raise ValueError( + "task_assignment_timeout ({}) needs to be less than or equal to task.timeout ({}).".format( + task_assignment_timeout, task.timeout + ) + ) + if not isinstance(send_order, SendOrder): + raise TypeError("send_order must be in Enum SendOrder, but got {}".format(type(send_order))) + + # targets must be provided + if targets is None or len(targets) == 0: + raise ValueError("Targets must be provided for send.") + + manager = SendTaskManager(task, send_order, task_assignment_timeout) + self._schedule_task( + task=task, + fl_ctx=fl_ctx, + manager=manager, + targets=targets, + ) + + def send_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + abort_signal: Signal = None, + ): + """Schedule a single task to targets. This is a blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means + clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + + """ + self.send( + task=task, + fl_ctx=fl_ctx, + targets=targets, + send_order=send_order, + task_assignment_timeout=task_assignment_timeout, + ) + self.wait_for_task(task, abort_signal) + + def get_num_standing_tasks(self) -> int: + """Get the number of tasks that are currently standing. + + Returns: + int: length of the list of standing tasks + """ + return len(self._tasks) + + def cancel_task( + self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None + ): + """Cancel the specified task. + + Change the task completion_status, which will inform task monitor to clean up this task + + note:: + + We only mark the task as completed and leave it to the task monitor to clean up. This is to avoid potential deadlock of task_lock. + + Args: + task (Task): the task to be cancelled + completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. + fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. + """ + task.completion_status = completion_status + + def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): + """Cancel all standing tasks in this controller. + + Args: + completion_status (str, optional): the completion status for this cancellation. 
Defaults to TaskCompletionStatus.CANCELLED. + fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. + """ + with self._task_lock: + for t in self._tasks: + t.completion_status = completion_status + + def finalize_run(self, fl_ctx: FLContext): + """Do cleanup of the coordinator implementation. + + .. attention:: + + Subclass controllers should not overwrite finalize_run. + + Args: + fl_ctx (FLContext): FLContext associated with this action + """ + self.cancel_all_tasks() # unconditionally cancel all tasks + self._all_done = True + try: + if self._task_monitor.is_alive(): + self._task_monitor.join() + except RuntimeError: + self.log_debug(fl_ctx, "unable to join monitor thread (not started?)") + + def relay( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + task_result_timeout: int = 0, + dynamic_targets: bool = True, + ): + """Schedule a single task to targets in one-after-another style. This is a non-blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means any clients that are inside the targets and haven't received the task are eligible. Defaults to SendOrder.SEQUENTIAL. + task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. + task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. + dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. 
+ + Raises: + ValueError: when task_assignment_timeout is greater than task's timeout + ValueError: when task_result_timeout is greater than task's timeout + TypeError: send_order is not defined in SendOrder + TypeError: when dynamic_targets is not a boolean variable + ValueError: targets is None or an empty list but dynamic_targets is False + """ + _check_inputs( + task=task, + fl_ctx=fl_ctx, + targets=targets, + ) + _check_positive_int("task_assignment_timeout", task_assignment_timeout) + _check_positive_int("task_result_timeout", task_result_timeout) + if task.timeout and task_assignment_timeout and task_assignment_timeout > task.timeout: + raise ValueError( + "task_assignment_timeout ({}) needs to be less than or equal to task.timeout ({}).".format( + task_assignment_timeout, task.timeout + ) + ) + if task.timeout and task_result_timeout and task_result_timeout > task.timeout: + raise ValueError( + "task_result_timeout ({}) needs to be less than or equal to task.timeout ({}).".format( + task_result_timeout, task.timeout + ) + ) + if not isinstance(send_order, SendOrder): + raise TypeError("send_order must be in Enum SendOrder, but got {}".format(type(send_order))) + if not isinstance(dynamic_targets, bool): + raise TypeError("dynamic_targets must be an instance of bool, but got {}".format(type(dynamic_targets))) + if targets is None and dynamic_targets is False: + raise ValueError("Need to provide targets when dynamic_targets is set to False.") + + if send_order == SendOrder.SEQUENTIAL: + manager = SequentialRelayTaskManager( + task=task, + task_assignment_timeout=task_assignment_timeout, + task_result_timeout=task_result_timeout, + dynamic_targets=dynamic_targets, + ) + else: + manager = AnyRelayTaskManager( + task=task, task_result_timeout=task_result_timeout, dynamic_targets=dynamic_targets + ) + + self._schedule_task( + task=task, + fl_ctx=fl_ctx, + manager=manager, + targets=targets, + allow_dup_targets=True, + ) + + def relay_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order=SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + task_result_timeout: int = 0, + dynamic_targets: bool = True, + abort_signal: Optional[Signal] = None, + ): + """Schedule a single task to targets in one-after-another style. This is a blocking call. + + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + + Args: + task (Task): the task to be scheduled + fl_ctx (FLContext): FLContext associated with this task + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means + clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. + task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. + dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. 
+ """ + self.relay( + task=task, + fl_ctx=fl_ctx, + targets=targets, + send_order=send_order, + task_assignment_timeout=task_assignment_timeout, + task_result_timeout=task_result_timeout, + dynamic_targets=dynamic_targets, + ) + self.wait_for_task(task, abort_signal) + + def _monitor_tasks(self): + while not self._all_done: + should_abort_job = self._job_policy_violated() + if not should_abort_job: + self._check_tasks() + else: + with self._engine.new_context() as fl_ctx: + self.system_panic("Aborting job due to deployment policy violation", fl_ctx) + return + time.sleep(self._task_check_period) + + def _check_tasks(self): + with self._controller_lock: + self._do_check_tasks() + + def _do_check_tasks(self): + exit_tasks = [] + with self._task_lock: + for task in self._tasks: + if task.completion_status is not None: + exit_tasks.append(task) + continue + + # check the task-specific exit condition + manager = task.props[_TASK_KEY_MANAGER] + if manager is not None: + if not isinstance(manager, TaskManager): + raise TypeError( + "manager in task must be an instance of TaskManager, but got {}".format(manager) + ) + should_exit, exit_status = manager.check_task_exit(task) + self.logger.debug("should_exit: {}, exit_status: {}".format(should_exit, exit_status)) + if should_exit: + task.completion_status = exit_status + exit_tasks.append(task) + continue + + # check if task timeout + if task.timeout and time.time() - task.schedule_time >= task.timeout: + task.completion_status = TaskCompletionStatus.TIMEOUT + exit_tasks.append(task) + continue + + # check whether clients that the task is waiting are all dead + dead_clients = self._get_task_dead_clients(task) + if dead_clients: + self.logger.info(f"client {dead_clients} is dead - set task {task.name} to TIMEOUT") + task.completion_status = TaskCompletionStatus.CLIENT_DEAD + exit_tasks.append(task) + continue + + for exit_task in exit_tasks: + exit_task.is_standing = False + self.logger.debug( + "Removing task={}, completion_status={}".format(exit_task, exit_task.completion_status) + ) + self._tasks.remove(exit_task) + for client_task in exit_task.client_tasks: + self.logger.debug("Removing client_task with id={}".format(client_task.id)) + self._client_task_map.pop(client_task.id) + + # do the task exit processing outside the lock to minimize the locking time + # and to avoid potential deadlock since the CB could schedule another task + if len(exit_tasks) <= 0: + return + + with self._engine.new_context() as fl_ctx: + for exit_task in exit_tasks: + with exit_task.cb_lock: + self.log_info( + fl_ctx, "task {} exit with status {}".format(exit_task.name, exit_task.completion_status) + ) + + if exit_task.task_done_cb is not None: + try: + exit_task.task_done_cb(task=exit_task, fl_ctx=fl_ctx) + except Exception as e: + self.log_exception( + fl_ctx, + "processing error in task_done_cb error on task {}: {}".format( + exit_task.name, secure_format_exception(e) + ), + ) + exit_task.completion_status = TaskCompletionStatus.ERROR + exit_task.exception = e + + def _get_task_dead_clients(self, task: Task): + """ + See whether the task is only waiting for response from a dead client + """ + now = time.time() + lead_time = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME, default=30.0) + if now - task.schedule_time < lead_time: + # due to potential race conditions, we'll wait for at least 1 minute after the task + # is started before checking dead clients. 
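            # [illustrative annotation, not a patch line] the actual wait is the configurable
            # dead_client_check_lead_time read via ConfigService just above (default 30.0
            # seconds here), not a hard-coded minute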
+ return None + + dead_clients = [] + with self._dead_clients_lock: + for target in task.targets: + ct = _get_client_task(target, task) + if ct is not None and ct.result_received_time: + # response has been received from this client + continue + + # either we have not sent the task to this client or we have not received response + # is the client already dead? + if self._client_still_alive(target): + # this client is still alive + # we let the task continue its course since we still have live clients + return None + else: + # this client is dead - remember it + dead_clients.append(target) + + return dead_clients + + @staticmethod + def _process_finished_task(task, func): + def wrap(*args, **kwargs): + if func: + func(*args, **kwargs) + task.props[_TASK_KEY_DONE] = True + + return wrap + + def wait_for_task(self, task: Task, abort_signal: Signal): + task.props[_TASK_KEY_DONE] = False + task.task_done_cb = self._process_finished_task(task=task, func=task.task_done_cb) + while True: + if task.completion_status is not None: + break + + if abort_signal and abort_signal.triggered: + self.cancel_task(task, fl_ctx=None, completion_status=TaskCompletionStatus.ABORTED) + break + + task_done = task.props[_TASK_KEY_DONE] + if task_done: + break + time.sleep(self._task_check_period) + + def _job_policy_violated(self): + if not self._engine: + return False + + with self._engine.new_context() as fl_ctx: + clients = self._engine.get_clients() + with self._dead_clients_lock: + alive_clients = [] + dead_clients = [] + + for client in clients: + if self._client_still_alive(client.name): + alive_clients.append(client.name) + else: + dead_clients.append(client.name) + + if not dead_clients: + return False + + if not alive_clients: + self.log_error(fl_ctx, f"All clients are dead: {dead_clients}") + return True + + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job = job_from_meta(job_meta) + if len(alive_clients) < job.min_sites: + self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}") + return True + + # check required clients: + if dead_clients and job.required_sites: + dead_required_clients = [c for c in dead_clients if c in job.required_sites] + if dead_required_clients: + self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}") + return True + return False + + def _client_still_alive(self, client_name): + now = time.time() + report_time = self._dead_client_reports.get(client_name, None) + grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=30.0) + + if not report_time: + # this client is still alive + return True + elif now - report_time < grace_period: + # this report is still fresh - consider the client to be still alive + return True + + return False diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py index 6072670886..21dc51a257 100644 --- a/nvflare/apis/responder.py +++ b/nvflare/apis/responder.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod from typing import Tuple -from nvflare.apis.signal import Signal from .client import Client from .fl_component import FLComponent @@ -93,19 +92,6 @@ def initialize_run(self, fl_ctx: FLContext): """ pass - @abstractmethod - def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): - """This is the control logic for the RUN. - - NOTE: this is running in a separate thread, and its life is the duration of the RUN. - - Args: - fl_ctx: the FL context - abort_signal: the abort signal. 
If triggered, this method stops waiting and returns to the caller. - - """ - pass - def finalize_run(self, fl_ctx: FLContext): """Called when a new RUN is finished. diff --git a/nvflare/apis/wf_comm_spec.py b/nvflare/apis/wf_comm_spec.py new file mode 100644 index 0000000000..bb313e5849 --- /dev/null +++ b/nvflare/apis/wf_comm_spec.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from typing import List, Union + +from nvflare.apis.client import Client +from nvflare.apis.controller_spec import SendOrder, Task +from nvflare.apis.fl_context import FLContext +from nvflare.apis.signal import Signal + + +class WFCommSpec(ABC): + def broadcast( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + min_responses: int = 0, + wait_time_after_min_received: int = 0, + ): + """Schedule to broadcast the task to specified targets. + + This is a non-blocking call. + + The task is standing until one of the following conditions comes true: + - if timeout is specified (> 0), and the task has been standing for more than the specified time + - the controller has received the specified min_responses results for this task, and all target clients + are done. + - the controller has received the specified min_responses results for this task, and has waited + for wait_time_after_min_received. + + While the task is standing: + - Before sending the task to a client, the before_task_sent CB (if specified) is called; + - When a result is received from a client, the result_received CB (if specified) is called; + + After the task is done, the task_done CB (if specified) is called: + - If result_received CB is specified, the 'result' in the ClientTask of each + client is produced by the result_received CB; + - Otherwise, the 'result' contains the original result submitted by the clients; + + NOTE: if the targets is None, the actual broadcast target clients will be dynamic, because the clients + could join/disconnect at any moment. While the task is standing, any client that joins automatically + becomes a target for this broadcast. + + Args: + task: the task to be sent + fl_ctx: the FL context + targets: list of destination clients. None means all clients are determined dynamically; + min_responses: the min number of responses expected. If == 0, must get responses from + all clients that the task has been sent to; + wait_time_after_min_received: how long (secs) to wait after the min_responses is received. + If == 0, end the task immediately after the min responses are received; + + """ + pass + + def broadcast_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + min_responses: int = 0, + wait_time_after_min_received: int = 0, + abort_signal: Signal = None, + ): + """This is the blocking version of the 'broadcast' method. 
+ + First, the task is scheduled for broadcast (see the broadcast method); + It then waits until the task is completed. + + Args: + task: the task to be sent + fl_ctx: the FL context + targets: list of destination clients. None means all clients are determined dynamically. + min_responses: the min number of responses expected. If == 0, must get responses from + all clients that the task has been sent to; + wait_time_after_min_received: how long (secs) to wait after the min_responses is received. + If == 0, end the task immediately after the min responses are received; + abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller. + + """ + pass + + def broadcast_forever( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + ): + """Schedule a broadcast task that never ends until timeout or explicitly cancelled. + + All clients will get the task every time it asks for a new task. + This is a non-blocking call. + + NOTE: you can change the content of the task in the before_task_sent function. + + Args: + task: the task to be sent + fl_ctx: the FL context + targets: list of destination clients. None means all clients are determined dynamically. + + """ + pass + + def send( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + ): + """Schedule to send the task to a single target client. + + This is a non-blocking call. + + In ANY order, the target client is the first target that asks for task. + In SEQUENTIAL order, the controller will try its best to send the task to the first client + in the targets list. If can't, it will try the next target, and so on. + + NOTE: if the 'targets' is None, the actual target clients will be dynamic, because the clients + could join/disconnect at any moment. While the task is standing, any client that joins automatically + becomes a target for this task. + + If the send_order is SEQUENTIAL, the targets must be a non-empty list of client names. + + Args: + task: the task to be sent + fl_ctx: the FL context + targets: list of candidate target clients. + send_order: how to choose the client to send the task. + task_assignment_timeout: in SEQUENTIAL order, this is the wait time for trying a target client, before trying next target. + + """ + pass + + def send_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + abort_signal: Signal = None, + ): + """This is the blocking version of the 'send' method. + + First, the task is scheduled for send (see the 'send' method); + It then waits until the task is completed and returns the task completion status and collected result. + + Args: + task: the task to be performed by each client + fl_ctx: the FL context for scheduling the task + targets: list of clients. If None, all clients. + send_order: how to choose the next client + task_assignment_timeout: how long to wait for the expected client to get assigned + before assigning to next client. + abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller. 
+ + """ + pass + + def relay( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order: SendOrder = SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + task_result_timeout: int = 0, + dynamic_targets: bool = True, + ): + """Schedules a task to be done sequentially by the clients in the targets list. This is a non-blocking call. + + Args: + task: the task to be performed by each client + fl_ctx: the FL context for scheduling the task + targets: list of clients. If None, all clients. + send_order: how to choose the next client + task_assignment_timeout: how long to wait for the expected client to get assigned + before assigning to next client. + task_result_timeout: how long to wait for result from the assigned client before giving up. + dynamic_targets: whether to dynamically grow the target list. If True, then the target list is + expanded dynamically when a new client joins. + + """ + pass + + def relay_and_wait( + self, + task: Task, + fl_ctx: FLContext, + targets: Union[List[Client], List[str], None] = None, + send_order=SendOrder.SEQUENTIAL, + task_assignment_timeout: int = 0, + task_result_timeout: int = 0, + dynamic_targets: bool = True, + abort_signal: Signal = None, + ): + """This is the blocking version of 'relay'.""" + pass diff --git a/nvflare/app_common/ccwf/client_ctl.py b/nvflare/app_common/ccwf/client_ctl.py index f19011b93c..cb1c00c48c 100644 --- a/nvflare/app_common/ccwf/client_ctl.py +++ b/nvflare/app_common/ccwf/client_ctl.py @@ -192,7 +192,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): return report = self._get_status_report() if not report: - self.log_info(fl_ctx, "nothing to report this time") + self.log_debug(fl_ctx, "nothing to report this time") return self._add_status_report(report, fl_ctx) self.last_status_report_time = report.timestamp diff --git a/nvflare/app_common/ccwf/server_ctl.py b/nvflare/app_common/ccwf/server_ctl.py index 5ec2e868a1..808acf9d33 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -445,7 +445,7 @@ def _update_client_status(self, fl_ctx: FLContext): # see whether status is available reports = peer_ctx.get_prop(Constant.STATUS_REPORTS) if not reports: - self.log_info(fl_ctx, f"no status report from client {client_name}") + self.log_debug(fl_ctx, f"no status report from client {client_name}") return my_report = reports.get(self.workflow_id) diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py index 5685a2e456..c0ffac7c94 100644 --- a/nvflare/private/fed/server/server_json_config.py +++ b/nvflare/private/fed/server/server_json_config.py @@ -16,7 +16,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import SystemConfigs, SystemVarName -from nvflare.apis.responder import Responder +from nvflare.apis.impl.controller import Controller from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -30,7 +30,7 @@ class WorkFlow: - def __init__(self, id, responder: Responder): + def __init__(self, id, controller: Controller): """Workflow is a responder with ID. 
Args: @@ -38,7 +38,7 @@ def __init__(self, id, responder: Responder): responder (Responder): A responder """ self.id = id - self.responder = responder + self.controller = controller class ServerJsonConfigurator(FedJsonConfigurator): @@ -125,10 +125,8 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node): if re.search(r"^workflows\.#[0-9]+$", path): workflow = self.authorize_and_build_component(element, config_ctx, node) - if not isinstance(workflow, Responder): - raise ConfigError( - '"workflow" must be a Responder or Controller object, but got {}'.format(type(workflow)) - ) + if not isinstance(workflow, Controller): + raise ConfigError('"workflow" must be a Controller object, but got {}'.format(type(workflow))) cid = element.get("id", None) if not cid: diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 75bb9b3d7e..224d2f6304 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -20,6 +20,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.apis.signal import Signal @@ -122,12 +123,15 @@ def _execute_run(self): wf = self.config.workflows[self.current_wf_index] try: with self.engine.new_context() as fl_ctx: - self.log_info(fl_ctx, "starting workflow {} ({}) ...".format(wf.id, type(wf.responder))) + self.log_info(fl_ctx, "starting workflow {} ({}) ...".format(wf.id, type(wf.controller))) fl_ctx.set_prop(FLContextKey.WORKFLOW, wf.id, sticky=True) - wf.responder.initialize_run(fl_ctx) - self.log_info(fl_ctx, "Workflow {} ({}) started".format(wf.id, type(wf.responder))) + wf.controller.set_communicator(WFCommServer(), fl_ctx) + wf.controller.communicator.initialize_run(fl_ctx) + wf.controller.start_controller(fl_ctx) + + self.log_info(fl_ctx, "Workflow {} ({}) started".format(wf.id, type(wf.controller))) self.log_debug(fl_ctx, "firing event EventType.START_WORKFLOW") self.fire_event(EventType.START_WORKFLOW, fl_ctx) @@ -137,7 +141,7 @@ def _execute_run(self): self.current_wf = wf with self.engine.new_context() as fl_ctx: - wf.responder.control_flow(self.abort_signal, fl_ctx) + wf.controller.control_flow(self.abort_signal, fl_ctx) except Exception as e: with self.engine.new_context() as fl_ctx: self.log_exception(fl_ctx, "Exception in workflow {}: {}".format(wf.id, secure_format_exception(e))) @@ -155,7 +159,8 @@ def _execute_run(self): self.log_info(fl_ctx, f"Workflow: {wf.id} finalizing ...") try: - wf.responder.finalize_run(fl_ctx) + wf.controller.stop_controller(fl_ctx) + wf.controller.communicator.finalize_run(fl_ctx) except Exception as e: self.log_exception( fl_ctx, "Error finalizing workflow {}: {}".format(wf.id, secure_format_exception(e)) @@ -304,7 +309,7 @@ def process_task_request(self, client: Client, fl_ctx: FLContext) -> (str, str, ) with self.wf_lock: if self.current_wf: - self.current_wf.responder.handle_exception(task_id, fl_ctx) + self.current_wf.controller.communicator.handle_exception(task_id, fl_ctx) return self._task_try_again() self.log_debug(fl_ctx, "firing event EventType.AFTER_TASK_DATA_FILTER") @@ -330,7 +335,9 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, 
retry_interval=0.005): self.log_debug(fl_ctx, "no current workflow - asked client to try again later") return "", "", None - task_name, task_id, task_data = self.current_wf.responder.process_task_request(client, fl_ctx) + task_name, task_id, task_data = self.current_wf.controller.communicator.process_task_request( + client, fl_ctx + ) if task_name and task_name != SpecialTaskName.TRY_AGAIN: if task_data: @@ -371,7 +378,7 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if self.current_wf is None: return - self.current_wf.responder.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx) + self.current_wf.controller.communicator.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx) except Exception as e: self.log_exception( fl_ctx, f"Error processing dead job by workflow {self.current_wf.id}: {secure_format_exception(e)}" @@ -475,7 +482,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_SUBMISSION") self.fire_event(EventType.BEFORE_PROCESS_SUBMISSION, fl_ctx) - self.current_wf.responder.process_submission( + self.current_wf.controller.communicator.process_submission( client=client, task_name=task_name, task_id=task_id, result=result, fl_ctx=fl_ctx ) self.log_info(fl_ctx, "finished processing client result by {}".format(self.current_wf.id)) @@ -501,11 +508,11 @@ def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) self.log_debug(fl_ctx, f"received task_check on task {task_id}") with self.wf_lock: - if self.current_wf is None or self.current_wf.responder is None: + if self.current_wf is None or self.current_wf.controller is None: self.log_info(fl_ctx, "no current workflow - dropped task_check.") return make_reply(ReturnCode.TASK_UNKNOWN) - task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx) + task = self.current_wf.controller.communicator.process_task_check(task_id=task_id, fl_ctx=fl_ctx) if task: self.log_debug(fl_ctx, f"task {task_id} is still good") return make_reply(ReturnCode.OK) From 16b31193d294302765d72edccdddfc694990ae98 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 7 Mar 2024 17:03:08 -0800 Subject: [PATCH 2/3] fix unit tests, clean up --- nvflare/apis/event_type.py | 1 + nvflare/apis/impl/controller.py | 36 ++-- nvflare/apis/impl/wf_comm_server.py | 5 +- nvflare/apis/responder.py | 1 - nvflare/app_common/ccwf/cse_server_ctl.py | 2 +- nvflare/app_common/ccwf/server_ctl.py | 10 +- .../workflows/broadcast_and_process.py | 3 +- .../workflows/broadcast_operator.py | 2 +- nvflare/app_common/workflows/cyclic_ctl.py | 3 +- .../app_common/workflows/model_controller.py | 4 +- .../workflows/scatter_and_gather.py | 4 +- .../app_common/workflows/splitnn_workflow.py | 3 +- .../workflows/statistics_controller.py | 3 +- nvflare/app_common/xgb/controller.py | 3 +- nvflare/private/fed/server/server_runner.py | 1 + .../autofedrl/autofedrl_scatter_and_gather.py | 2 +- .../app/custom/custom_controller.py | 3 +- tests/unit_test/apis/impl/controller_test.py | 173 +++++++++--------- .../workflow/mock_statistics_controller.py | 2 +- 19 files changed, 144 insertions(+), 117 deletions(-) diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 942da170c5..929c942e4c 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -37,6 +37,7 @@ class EventType(object): BEFORE_PULL_TASK = "_before_pull_task" AFTER_PULL_TASK = "_after_pull_task" + BEFORE_PROCESS_TASK = "_before_process_task_request" 
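
(Illustrative sketch, not part of the patch: the new BEFORE_PROCESS_TASK event is fired by the
server runner just before the communicator handles a client's task request, as the
server_runner.py hunk below shows. The snippet sketches how a workflow controller might hook
this event, mirroring the server_ctl.py hunk below; the class name and the
_update_client_status body are hypothetical, and a concrete controller must still implement
the remaining ControllerSpec methods.)

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.signal import Signal


class StatusTrackingController(Controller):
    """Hypothetical controller that refreshes client status on every task request."""

    def start_controller(self, fl_ctx: FLContext):
        pass

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        pass

    def stop_controller(self, fl_ctx: FLContext):
        pass

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.BEFORE_PROCESS_TASK:
            # fired right before the communicator processes this client's task request
            self._update_client_status(fl_ctx)

    def _update_client_status(self, fl_ctx: FLContext):
        # hypothetical per-request bookkeeping (e.g. recording peer status reports)
        pass
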
BEFORE_PROCESS_SUBMISSION = "_before_process_submission" AFTER_PROCESS_SUBMISSION = "_after_process_submission" diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 1297f35587..7f9f15bd5e 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -15,28 +15,33 @@ from typing import List, Optional, Union from nvflare.apis.client import Client -from nvflare.apis.controller_spec import ClientTask, ControllerSpec, SendOrder, Task, TaskCompletionStatus +from nvflare.apis.controller_spec import ControllerSpec, SendOrder, Task, TaskCompletionStatus from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.signal import Signal from nvflare.apis.wf_comm_spec import WFCommSpec class Controller(FLComponent, ControllerSpec, ABC): def __init__(self, task_check_period=0.2): - """Manage life cycles of tasks and their destinations. + """Controller logic for tasks and their destinations. + + Must set_communicator() to access communication related function implementations. Args: - task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2. + task_check_period (float, optional): interval for checking status of tasks. Applicable for WFCommServer. Defaults to 0.2. """ super().__init__() self._task_check_period = task_check_period self.communicator = None def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext): - communicator.task_check_period = self._task_check_period + if not isinstance(communicator, WFCommSpec): + raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}") + self.communicator = communicator + self.communicator.controller = self + self.communicator.task_check_period = self._task_check_period engine = fl_ctx.get_engine() if not engine: self.system_panic(f"Engine not found. 
{self.__class__.__name__} exiting.", fl_ctx) @@ -128,18 +133,21 @@ def relay_and_wait( ) def get_num_standing_tasks(self) -> int: - if not isinstance(self.communicator, WFCommServer): - raise NotImplementedError - return self.communicator.get_num_standing_tasks() + try: + return self.communicator.get_num_standing_tasks() + except: + raise NotImplementedError(f"{self.communicator} does not support this function") def cancel_task( self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None ): - if not isinstance(self.communicator, WFCommServer): - raise NotImplementedError - self.communicator.cancel_task(task, completion_status, fl_ctx) + try: + self.communicator.cancel_task(task, completion_status, fl_ctx) + except: + raise NotImplementedError(f"{self.communicator} does not support this function") def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): - if not isinstance(self.communicator, WFCommServer): - raise NotImplementedError - self.communicator.cancel_all_tasks(completion_status, fl_ctx) + try: + self.communicator.cancel_all_tasks(completion_status, fl_ctx) + except: + raise NotImplementedError(f"{self.communicator} does not support this function") diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index b15f555c1a..a1ffc1f336 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -86,6 +86,7 @@ def __init__(self, task_check_period=0.2): task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2. """ super().__init__() + self.controller = None self._engine = None self._tasks = [] # list of standing tasks self._client_task_map = {} # client_task_id => client_task @@ -343,6 +344,8 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if not self._dead_client_reports.get(client_name): self._dead_client_reports[client_name] = time.time() + self.controller.handle_dead_job(client_name, fl_ctx) + def process_task_check(self, task_id: str, fl_ctx: FLContext): with self._task_lock: # task_id is the uuid associated with the client_task @@ -397,7 +400,7 @@ def _do_process_submission( if client_task is None: # cannot find a standing task for the submission self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id)) - self.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) + self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) return task = client_task.task diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py index 21dc51a257..9935c430ac 100644 --- a/nvflare/apis/responder.py +++ b/nvflare/apis/responder.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod from typing import Tuple - from .client import Client from .fl_component import FLComponent from .fl_context import FLContext diff --git a/nvflare/app_common/ccwf/cse_server_ctl.py b/nvflare/app_common/ccwf/cse_server_ctl.py index b55dc41800..f95e5fdc05 100644 --- a/nvflare/app_common/ccwf/cse_server_ctl.py +++ b/nvflare/app_common/ccwf/cse_server_ctl.py @@ -14,9 +14,9 @@ import os +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.dxo import from_shareable from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Task from nvflare.apis.shareable import ReturnCode, Shareable from nvflare.apis.signal import Signal from nvflare.apis.workspace import 
Workspace diff --git a/nvflare/app_common/ccwf/server_ctl.py b/nvflare/app_common/ccwf/server_ctl.py index 808acf9d33..55f31643cc 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -16,9 +16,11 @@ from datetime import datetime from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task +from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import ReturnCode, Shareable from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants @@ -341,9 +343,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): self.log_info(fl_ctx, f"Workflow {self.workflow_id} done!") - def process_task_request(self, client: Client, fl_ctx: FLContext): - self._update_client_status(fl_ctx) - return super().process_task_request(client, fl_ctx) + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.BEFORE_PROCESS_TASK: + self._update_client_status(fl_ctx) def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool: return True diff --git a/nvflare/app_common/workflows/broadcast_and_process.py b/nvflare/app_common/workflows/broadcast_and_process.py index fd2a8216aa..d015817c30 100644 --- a/nvflare/app_common/workflows/broadcast_and_process.py +++ b/nvflare/app_common/workflows/broadcast_and_process.py @@ -15,8 +15,9 @@ from typing import Union from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.response_processor import ResponseProcessor diff --git a/nvflare/app_common/workflows/broadcast_operator.py b/nvflare/app_common/workflows/broadcast_operator.py index 8827e6209f..ce548cd95e 100644 --- a/nvflare/app_common/workflows/broadcast_operator.py +++ b/nvflare/app_common/workflows/broadcast_operator.py @@ -16,11 +16,11 @@ from typing import Dict, List, Optional, Union from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.dxo import DXO, from_shareable from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Task from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.workflows.error_handling_controller import ErrorHandlingController diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index 754e1b06b6..65f8cb9195 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -16,9 +16,10 @@ import random from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from 
nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 0b65f07539..cd6b565241 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -16,10 +16,10 @@ from typing import List, Union from nvflare.apis.client import Client -from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey +from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.fl_model import FLModel, ParamsType diff --git a/nvflare/app_common/workflows/scatter_and_gather.py b/nvflare/app_common/workflows/scatter_and_gather.py index e663d81e5a..2255c6a160 100644 --- a/nvflare/app_common/workflows/scatter_and_gather.py +++ b/nvflare/app_common/workflows/scatter_and_gather.py @@ -15,10 +15,10 @@ from typing import Any from nvflare.apis.client import Client -from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey +from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.aggregator import Aggregator diff --git a/nvflare/app_common/workflows/splitnn_workflow.py b/nvflare/app_common/workflows/splitnn_workflow.py index ae3e8a17a8..b16df0e837 100644 --- a/nvflare/app_common/workflows/splitnn_workflow.py +++ b/nvflare/app_common/workflows/splitnn_workflow.py @@ -13,9 +13,10 @@ # limitations under the License. 
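
(Aside, not part of the patch: the workflow-module hunks in this area are mechanical import
relocations; the behavior they rely on is the controller/communicator split wired up by the
server runner. Below is a rough sketch of that server-side lifecycle, using the names
introduced by this patch; run_workflow is a hypothetical stand-in for what
ServerRunner._execute_run does, with error handling and per-step FLContext creation omitted.)

from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.impl.wf_comm_server import WFCommServer
from nvflare.apis.signal import Signal


def run_workflow(controller: Controller, fl_ctx: FLContext, abort_signal: Signal):
    comm = WFCommServer()
    controller.set_communicator(comm, fl_ctx)  # controller delegates task scheduling to comm
    comm.initialize_run(fl_ctx)                # starts the server-side task monitor thread
    controller.start_controller(fl_ctx)
    try:
        controller.control_flow(abort_signal, fl_ctx)
    finally:
        controller.stop_controller(fl_ctx)
        comm.finalize_run(fl_ctx)              # cancels standing tasks and joins the monitor
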
from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor diff --git a/nvflare/app_common/workflows/statistics_controller.py b/nvflare/app_common/workflows/statistics_controller.py index 1c83bba99d..bc59aefbd8 100644 --- a/nvflare/app_common/workflows/statistics_controller.py +++ b/nvflare/app_common/workflows/statistics_controller.py @@ -16,11 +16,12 @@ from typing import Callable, Dict, List, Optional from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.dxo import from_shareable from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, StatisticConfig diff --git a/nvflare/app_common/xgb/controller.py b/nvflare/app_common/xgb/controller.py index 83bd6f90a1..5e96f95aa2 100644 --- a/nvflare/app_common/xgb/controller.py +++ b/nvflare/app_common/xgb/controller.py @@ -15,8 +15,9 @@ import time from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import ReturnCode, Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.xgb.adaptors.adaptor import XGBServerAdaptor diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 224d2f6304..7f5a028697 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -335,6 +335,7 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005): self.log_debug(fl_ctx, "no current workflow - asked client to try again later") return "", "", None + self.fire_event(EventType.BEFORE_PROCESS_TASK, fl_ctx) task_name, task_id, task_data = self.current_wf.controller.communicator.process_task_request( client, fl_ctx ) diff --git a/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py b/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py index 20a44f064b..6d2e3bd8a3 100644 --- a/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py +++ b/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py @@ -14,10 +14,10 @@ import traceback +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Task from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.aggregator import Aggregator diff --git 
a/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py b/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py index dd36aad0f3..9a585eb81c 100755 --- a/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py +++ b/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py @@ -13,8 +13,9 @@ # limitations under the License. from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor diff --git a/tests/unit_test/apis/impl/controller_test.py b/tests/unit_test/apis/impl/controller_test.py index 2cfc656d44..439199e267 100644 --- a/tests/unit_test/apis/impl/controller_test.py +++ b/tests/unit_test/apis/impl/controller_test.py @@ -26,6 +26,7 @@ from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus from nvflare.apis.fl_context import FLContext, FLContextManager from nvflare.apis.impl.controller import Controller +from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.shareable import ReservedHeaderKey, Shareable from nvflare.apis.signal import Signal @@ -123,7 +124,9 @@ def _setup_system(num_clients=1): controller = DummyController() fl_ctx = mock_server_engine.new_context() - controller.initialize_run(fl_ctx=fl_ctx) + communicator = WFCommServer() + controller.set_communicator(communicator, fl_ctx) + controller.communicator.initialize_run(fl_ctx=fl_ctx) return controller, mock_server_engine, fl_ctx, clients_list @@ -139,7 +142,7 @@ def setup_system(num_of_clients=1): @staticmethod def teardown_system(controller, fl_ctx): - controller.finalize_run(fl_ctx=fl_ctx) + controller.communicator.finalize_run(fl_ctx=fl_ctx) class TestTaskManagement(TestController): @@ -230,7 +233,7 @@ def test_check_task_remove_cancelled_tasks(self, method, num_of_start_tasks, num for i in range(num_of_cancel_tasks): controller.cancel_task(task=all_tasks[i], fl_ctx=fl_ctx) assert all_tasks[i].completion_status == TaskCompletionStatus.CANCELLED - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == (num_of_start_tasks - num_of_cancel_tasks) controller.cancel_all_tasks() for thread in all_threads: @@ -256,11 +259,11 @@ def test_client_request_after_cancel_task(self, method, num_client_requests): get_ready(launch_thread) controller.cancel_task(task) for i in range(num_client_requests): - _, task_id, data = controller.process_task_request(client, fl_ctx) + _, task_id, data = controller.communicator.process_task_request(client, fl_ctx) # check if task_id is empty means this task is not assigned assert task_id == "" assert data is None - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED launch_thread.join() @@ -282,18 +285,18 @@ def test_client_submit_result_after_cancel_task(self, method): }, ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = 
controller.communicator.process_task_request(client, fl_ctx) controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED time.sleep(1) - print(controller._tasks) + print(controller.communicator._tasks) # in here we make up client results: result = Shareable() result["result"] = "result" with pytest.raises(RuntimeError, match="Unknown task: __test_task from client __test_client0."): - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) @@ -531,7 +534,7 @@ def test_process_submission_invalid_input(self, method, kwargs, error, msg): controller, fl_ctx, clients = self.setup_system() with pytest.raises(error, match=msg): - controller.process_submission(**kwargs) + controller.communicator.process_submission(**kwargs) self.teardown_system(controller, fl_ctx) @@ -573,13 +576,13 @@ def clients_pull_and_submit_result(controller, ctx, clients, task_name): client_task_ids = [] num_of_clients = len(clients) for i in range(num_of_clients): - task_name_out, client_task_id, data = controller.process_task_request(clients[i], ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[i], ctx) assert task_name_out == task_name client_task_ids.append(client_task_id) for client, client_task_id in zip(clients, client_task_ids): data = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=client, task_name=task_name, task_id=client_task_id, fl_ctx=ctx, result=data ) @@ -606,7 +609,7 @@ def before_task_sent_cb(client_task: ClientTask, **kwargs): ) get_ready(launch_thread) - task_name_out, _, data = controller.process_task_request(client, fl_ctx) + task_name_out, _, data = controller.communicator.process_task_request(client, fl_ctx) assert data["_test_data"] == client_name controller.cancel_task(task) @@ -636,13 +639,13 @@ def result_received_cb(client_task: ClientTask, **kwargs): }, ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) - controller.process_submission( + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=data ) assert task.last_client_task_map[client_name].result["_test_data"] == client_name - controller._check_tasks() + controller.communicator._check_tasks() assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -670,7 +673,7 @@ def test_task_done_cb(self, method, num_clients, task_name, input_data, cb, expe client_task_ids = len(clients) * [None] for i, client in enumerate(clients): - task_name_out, client_task_ids[i], _ = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_ids[i], _ = controller.communicator.process_task_request(client, fl_ctx) if task_name_out == "": client_task_ids[i] = None @@ -682,17 +685,17 @@ def test_task_done_cb(self, method, num_clients, task_name, input_data, cb, expe for client, client_task_id in zip(clients, client_task_ids): if client_task_id is not None: if task_complete == "normal": - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) if task_complete == "timeout": time.sleep(timeout) - 
controller._check_tasks() + controller.communicator._check_tasks() assert task.completion_status == TaskCompletionStatus.TIMEOUT elif task_complete == "cancel": controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED - controller._check_tasks() + controller.communicator._check_tasks() assert task.props[task_name] == expected assert controller.get_num_standing_tasks() == 0 launch_thread.join() @@ -718,7 +721,7 @@ def before_task_sent_cb(client_task: ClientTask, **kwargs): ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -748,16 +751,16 @@ def result_received_cb(client_task: ClientTask, **kwargs): ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client1, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client1, fl_ctx) result = Shareable() result["__result"] = "__test_result" - controller.process_submission( + controller.communicator.process_submission( client=client1, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) assert task.last_client_task_map["__test_client0"].result == result - task_name_out, client_task_id, data = controller.process_task_request(client2, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client2, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -803,12 +806,12 @@ def before_task_sent_cb(client_task: ClientTask, fl_ctx: FLContext): task_name_out = "" while task_name_out == "": - task_name_out, _, _ = controller.process_task_request(client, ctx) + task_name_out, _, _ = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert task_name_out == "__test_task" new_task_name_out = "" while new_task_name_out == "": - new_task_name_out, _, _ = controller.process_task_request(client, ctx) + new_task_name_out, _, _ = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert new_task_name_out == "__new_test_task" @@ -857,18 +860,18 @@ def result_received_cb(client_task: ClientTask, fl_ctx: FLContext): client_task_id = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert task_name_out == "__test_task" - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=ctx, result=data ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 new_task_name_out = "" while new_task_name_out == "": - new_task_name_out, _, _ = controller.process_task_request(client, ctx) + new_task_name_out, _, _ = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert new_task_name_out == "__new_test_task" launch_thread.join() @@ -911,14 +914,14 @@ def result_received_cb(client_task: ClientTask, fl_ctx: FLContext): launch_thread.start() clients_pull_and_submit_result(controller=controller, ctx=ctx, clients=clients, task_name="__test_task") - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 
num_of_clients for i in range(num_of_clients): clients_pull_and_submit_result( controller=controller, ctx=ctx, clients=clients, task_name=f"__new_test_task_{clients[i].name}" ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == num_of_clients - (i + 1) launch_thread.join() @@ -931,7 +934,7 @@ def test_process_submission_invalid_task(self, task_name, client_name): controller, fl_ctx, clients = self.setup_system() client = clients[0] with pytest.raises(RuntimeError, match=f"Unknown task: {task_name} from client {client_name}."): - controller.process_submission( + controller.communicator.process_submission( client=client, task_name=task_name, task_id=str(uuid.uuid4()), fl_ctx=FLContext(), result=Shareable() ) self.teardown_system(controller, fl_ctx) @@ -957,7 +960,7 @@ def test_process_task_request_client_request_multiple_times(self, method, num_cl get_ready(launch_thread) for i in range(num_client_requests): - task_name_out, _, data = controller.process_task_request(client, fl_ctx) + task_name_out, _, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map["__test_client0"].task_send_count == num_client_requests @@ -982,12 +985,12 @@ def test_process_submission(self, method): ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) # in here we make up client results: result = Shareable() result["result"] = "result" - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) assert task.last_client_task_map["__test_client0"].result == result @@ -1039,7 +1042,7 @@ def test_cancel_task(self, method): assert controller.get_num_standing_tasks() == 1 controller.cancel_task(task=task) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED launch_thread.join() @@ -1076,7 +1079,7 @@ def test_cancel_all_tasks(self, method): assert controller.get_num_standing_tasks() == 2 controller.cancel_all_tasks() - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED assert task1.completion_status == TaskCompletionStatus.CANCELLED @@ -1111,19 +1114,19 @@ def test_client_receive_only_one_task(self, method, num_of_clients): client_task_id = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[client.name].task_send_count == 1 assert controller.get_num_standing_tasks() == 1 - _, next_client_task_id, _ = controller.process_task_request(client, fl_ctx) + _, next_client_task_id, _ = controller.communicator.process_task_request(client, fl_ctx) assert next_client_task_id == client_task_id assert task.last_client_task_map[client.name].task_send_count == 2 result = Shareable() result["result"] = "result" - controller.process_submission( + 
controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, @@ -1132,7 +1135,7 @@ def test_client_receive_only_one_task(self, method, num_of_clients): ) assert task.last_client_task_map[client.name].result == result - controller._check_tasks() + controller.communicator._check_tasks() assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -1160,7 +1163,7 @@ def test_only_client_in_target_will_get_task(self, method, num_of_clients): task_name_out = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1168,7 +1171,7 @@ def test_only_client_in_target_will_get_task(self, method, num_of_clients): assert controller.get_num_standing_tasks() == 1 for client in clients[1:]: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -1199,19 +1202,19 @@ def test_task_only_exit_when_min_responses_received(self, method, min_responses) client_task_ids = [] for client in clients: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == "__test_task" for client, client_task_id in zip(clients, client_task_ids): result = Shareable() - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1244,17 +1247,17 @@ def test_task_exit_quickly_when_all_responses_received(self, method, min_respons client_task_ids = [] for client in clients: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == "__test_task" for client, client_task_id in zip(clients, client_task_ids): result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1284,21 +1287,21 @@ def test_min_resp_is_zero_task_only_exit_when_all_client_task_done(self, method, client_task_ids = [] for client in clients: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == 
"__test_task" - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 for client, client_task_id in zip(clients, client_task_ids): - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1604,7 +1607,7 @@ def test_only_client_in_target_will_get_task(self, method, send_order): task_name_out = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1612,7 +1615,7 @@ def test_only_client_in_target_will_get_task(self, method, send_order): assert controller.get_num_standing_tasks() == 1 for client in clients[1:]: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -1650,7 +1653,7 @@ def test_task_assignment_timeout_sequential_order_only_client_in_target_will_get time.sleep(task_assignment_timeout + 1) for client in clients[1:]: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -1695,7 +1698,7 @@ def test_process_task_request( assert controller.get_num_standing_tasks() == 1 time.sleep(time_before_first_request) - task_name, task_id, data = controller.process_task_request(client=request_client, fl_ctx=fl_ctx) + task_name, task_id, data = controller.communicator.process_task_request(client=request_client, fl_ctx=fl_ctx) client_get_a_task = True if task_name == "__test_task" else False assert client_get_a_task == expected_to_get_task @@ -1729,7 +1732,7 @@ def test_sequential_sequence(self, method, targets): client_tasks_and_results = {} for c in targets: - task_name, task_id, data = controller.process_task_request(client=c, fl_ctx=fl_ctx) + task_name, task_id, data = controller.communicator.process_task_request(client=c, fl_ctx=fl_ctx) if task_name != "": client_result = Shareable() client_result["result"] = f"{c.name}" @@ -1740,7 +1743,7 @@ def test_sequential_sequence(self, method, targets): for task_id in client_tasks_and_results.keys(): c, task_name, client_result = client_tasks_and_results[task_id] task.data["result"] += client_result["result"] - controller.process_submission( + controller.communicator.process_submission( client=c, task_name=task_name, task_id=task_id, result=client_result, fl_ctx=fl_ctx ) assert task.last_client_task_map[c.name].result == client_result @@ -1797,22 +1800,24 @@ def test_process_request_and_submission_with_task_assignment_timeout( data = None task_name_out = "" while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, 
data = controller.communicator.process_task_request( + client, fl_ctx + ) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[client.name].task_send_count == 1 else: - _task_name_out, _client_task_id, _ = controller.process_task_request(client, fl_ctx) + _task_name_out, _client_task_id, _ = controller.communicator.process_task_request(client, fl_ctx) assert _task_name_out == "" assert _client_task_id == "" # client side running some logic to generate result if expected_client_to_get_task: - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=expected_client_to_get_task, task_name=task_name_out, task_id=client_task_id, @@ -1821,7 +1826,7 @@ def test_process_request_and_submission_with_task_assignment_timeout( ) launch_thread.join() - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 self.teardown_system(controller, fl_ctx) @@ -1856,7 +1861,7 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, task_name_out = "" old_client_task_id = "" while task_name_out == "": - task_name_out, old_client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, old_client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1865,21 +1870,21 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, time.sleep(task_result_timeout + 1) # same client ask should get the same task - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) assert client_task_id == old_client_task_id assert task.last_client_task_map[clients[0].name].task_send_count == 2 time.sleep(task_result_timeout + 1) # second client ask should get a task since task_result_timeout passed - task_name_out, client_task_id_1, data = controller.process_task_request(clients[1], fl_ctx) + task_name_out, client_task_id_1, data = controller.communicator.process_task_request(clients[1], fl_ctx) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[clients[1].name].task_send_count == 1 # then we get back first client's result result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=clients[0], task_name=task_name_out, task_id=client_task_id, @@ -1889,7 +1894,7 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, # need to make sure the header is set assert result.get_header(ReservedHeaderKey.REPLY_IS_LATE) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 self.teardown_system(controller, fl_ctx) @@ -1925,7 +1930,7 @@ def test_process_submission_all_client_task_result_timeout(self, method, send_or task_name_out = "" while task_name_out == "": - task_name_out, old_client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, old_client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1951,7 +1956,7 @@ def 
_assert_other_clients_get_no_task(controller, fl_ctx, client_idx: int, clien for i, client in enumerate(clients): if i == client_idx: continue - _task_name_out, _client_task_id, data = controller.process_task_request(client, fl_ctx) + _task_name_out, _client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert _task_name_out == "" assert _client_task_id == "" @@ -2019,12 +2024,12 @@ def test_process_task_request_client_not_in_target_get_nothing(self, method, sen assert controller.get_num_standing_tasks() == 1 # this client not in target so should get nothing - _task_name_out, _client_task_id, data = controller.process_task_request(client, fl_ctx) + _task_name_out, _client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert _task_name_out == "" assert _client_task_id == "" controller.cancel_task(task) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -2054,7 +2059,9 @@ def test_process_task_request_expected_client_get_task_and_unexpected_clients_ge task_name_out = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(targets[client_idx], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request( + targets[client_idx], fl_ctx + ) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -2065,7 +2072,7 @@ def test_process_task_request_expected_client_get_task_and_unexpected_clients_ge controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -2113,13 +2120,13 @@ def test_process_task_request_with_task_assignment_timeout_expected_client_get_t if client.name == expected_client_to_get_task: task_name_out = "" while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[client.name].task_send_count == 1 else: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -2153,7 +2160,7 @@ def test_send_only_one_task_and_exit_when_client_task_done(self, method, num_of_ client_task_id = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -2162,14 +2169,14 @@ def test_send_only_one_task_and_exit_when_client_task_done(self, method, num_of_ # once a client gets a task, other clients should not get task _assert_other_clients_get_no_task(controller=controller, fl_ctx=fl_ctx, client_idx=0, clients=clients) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 - controller.process_submission( + 
controller.communicator.process_submission( client=clients[0], task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=data ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() diff --git a/tests/unit_test/app_common/workflow/mock_statistics_controller.py b/tests/unit_test/app_common/workflow/mock_statistics_controller.py index f6bc0e8eb9..9f6d850035 100644 --- a/tests/unit_test/app_common/workflow/mock_statistics_controller.py +++ b/tests/unit_test/app_common/workflow/mock_statistics_controller.py @@ -15,8 +15,8 @@ from typing import Dict from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.workflows.statistics_controller import StatisticsController From 672d74b82b68ba165891faa534b31e2cb02267b6 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Fri, 8 Mar 2024 12:41:16 -0800 Subject: [PATCH 3/3] remove responder, address comments --- nvflare/apis/event_type.py | 3 +- nvflare/apis/impl/controller.py | 31 ++--- nvflare/apis/impl/wf_comm_server.py | 12 +- nvflare/apis/responder.py | 101 -------------- nvflare/apis/wf_comm_spec.py | 127 ++++++++++++++++-- nvflare/app_common/ccwf/server_ctl.py | 2 +- .../private/fed/server/server_json_config.py | 8 +- nvflare/private/fed/server/server_runner.py | 8 +- tests/unit_test/apis/impl/controller_test.py | 51 +++---- 9 files changed, 176 insertions(+), 167 deletions(-) delete mode 100644 nvflare/apis/responder.py diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 929c942e4c..4c65843e70 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -37,7 +37,8 @@ class EventType(object): BEFORE_PULL_TASK = "_before_pull_task" AFTER_PULL_TASK = "_after_pull_task" - BEFORE_PROCESS_TASK = "_before_process_task_request" + BEFORE_PROCESS_TASK_REQUEST = "_before_process_task_request" + AFTER_PROCESS_TASK_REQUEST = "_after_process_task_request" BEFORE_PROCESS_SUBMISSION = "_before_process_submission" AFTER_PROCESS_SUBMISSION = "_after_process_submission" diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 7f9f15bd5e..92d190805b 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -35,13 +35,7 @@ def __init__(self, task_check_period=0.2): self._task_check_period = task_check_period self.communicator = None - def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext): - if not isinstance(communicator, WFCommSpec): - raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}") - - self.communicator = communicator - self.communicator.controller = self - self.communicator.task_check_period = self._task_check_period + def initialize(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() if not engine: self.system_panic(f"Engine not found. 
{self.__class__.__name__} exiting.", fl_ctx) @@ -49,6 +43,14 @@ def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext): self._engine = engine + def set_communicator(self, communicator: WFCommSpec): + if not isinstance(communicator, WFCommSpec): + raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}") + + self.communicator = communicator + self.communicator.controller = self + self.communicator.task_check_period = self._task_check_period + def broadcast( self, task: Task, @@ -133,21 +135,12 @@ def relay_and_wait( ) def get_num_standing_tasks(self) -> int: - try: - return self.communicator.get_num_standing_tasks() - except: - raise NotImplementedError(f"{self.communicator} does not support this function") + return self.communicator.get_num_standing_tasks() def cancel_task( self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None ): - try: - self.communicator.cancel_task(task, completion_status, fl_ctx) - except: - raise NotImplementedError(f"{self.communicator} does not support this function") + self.communicator.cancel_task(task, completion_status, fl_ctx) def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): - try: - self.communicator.cancel_all_tasks(completion_status, fl_ctx) - except: - raise NotImplementedError(f"{self.communicator} does not support this function") + self.communicator.cancel_all_tasks(completion_status, fl_ctx) diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index a1ffc1f336..9ecbc05feb 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,10 +18,10 @@ from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus +from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import job_from_meta -from nvflare.apis.responder import Responder from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy from nvflare.apis.signal import Signal from nvflare.apis.wf_comm_spec import WFCommSpec @@ -78,7 +78,7 @@ def _get_client_task(target, task: Task): return None -class WFCommServer(Responder, WFCommSpec): +class WFCommServer(FLComponent, WFCommSpec): def __init__(self, task_check_period=0.2): """Manage life cycles of tasks and their destinations. 
@@ -96,7 +96,7 @@ def __init__(self, task_check_period=0.2): self._task_check_period = task_check_period self._dead_client_reports = {} # clients that reported the job is dead on it: name => report time self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads - # make sure _check_tasks, process_task_request, process_submission does not interfere with each other + # make sure check_tasks, process_task_request, process_submission does not interfere with each other self._controller_lock = Lock() def initialize_run(self, fl_ctx: FLContext): @@ -822,14 +822,14 @@ def _monitor_tasks(self): while not self._all_done: should_abort_job = self._job_policy_violated() if not should_abort_job: - self._check_tasks() + self.check_tasks() else: with self._engine.new_context() as fl_ctx: self.system_panic("Aborting job due to deployment policy violation", fl_ctx) return time.sleep(self._task_check_period) - def _check_tasks(self): + def check_tasks(self): with self._controller_lock: self._do_check_tasks() diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py deleted file mode 100644 index 9935c430ac..0000000000 --- a/nvflare/apis/responder.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Tuple - -from .client import Client -from .fl_component import FLComponent -from .fl_context import FLContext -from .shareable import Shareable - - -class Responder(FLComponent, ABC): - def __init__(self): - """Init the Responder. - - Base class for responding to clients. Controller is a subclass of Responder. - """ - FLComponent.__init__(self) - - @abstractmethod - def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: - """Called by the Engine when a task request is received from a client. - - Args: - client: the Client that the task request is from - fl_ctx: the FLContext - - Returns: task name, task id, and task data - - """ - pass - - @abstractmethod - def handle_exception(self, task_id: str, fl_ctx: FLContext): - """Called after process_task_request returns, but exception occurs before task is sent out.""" - pass - - @abstractmethod - def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext): - """Called by the Engine to process the submitted result from a client. - - Args: - client: the Client that the submitted result is from - task_name: the name of the task - task_id: the id of the task - result: the Shareable result from the Client - fl_ctx: the FLContext - - """ - pass - - @abstractmethod - def process_task_check(self, task_id: str, fl_ctx: FLContext): - """Called by the Engine to check whether a specified task still exists. 
- Args: - task_id: the id of the task - fl_ctx: the FLContext - Returns: the ClientTask object if exists; None otherwise - """ - pass - - @abstractmethod - def handle_dead_job(self, client_name: str, fl_ctx: FLContext): - """Called by the Engine to handle the case that the job on the client is dead. - - Args: - client_name: name of the client on which the job is dead - fl_ctx: the FLContext - - """ - pass - - def initialize_run(self, fl_ctx: FLContext): - """Called when a new RUN is about to start. - - Args: - fl_ctx: FL context. It must contain 'job_id' that is to be initialized - - """ - pass - - def finalize_run(self, fl_ctx: FLContext): - """Called when a new RUN is finished. - - Args: - fl_ctx: the FL context - - """ - pass diff --git a/nvflare/apis/wf_comm_spec.py b/nvflare/apis/wf_comm_spec.py index bb313e5849..8350b16b7a 100644 --- a/nvflare/apis/wf_comm_spec.py +++ b/nvflare/apis/wf_comm_spec.py @@ -13,11 +13,12 @@ # limitations under the License. from abc import ABC -from typing import List, Union +from typing import List, Optional, Tuple, Union from nvflare.apis.client import Client -from nvflare.apis.controller_spec import SendOrder, Task +from nvflare.apis.controller_spec import SendOrder, Task, TaskCompletionStatus from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal @@ -64,7 +65,7 @@ def broadcast( If == 0, end the task immediately after the min responses are received; """ - pass + raise NotImplementedError def broadcast_and_wait( self, @@ -91,7 +92,7 @@ def broadcast_and_wait( abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller. """ - pass + raise NotImplementedError def broadcast_forever( self, @@ -112,7 +113,7 @@ def broadcast_forever( targets: list of destination clients. None means all clients are determined dynamically. """ - pass + raise NotImplementedError def send( self, @@ -144,7 +145,7 @@ def send( task_assignment_timeout: in SEQUENTIAL order, this is the wait time for trying a target client, before trying next target. """ - pass + raise NotImplementedError def send_and_wait( self, @@ -170,7 +171,7 @@ def send_and_wait( abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller. """ - pass + raise NotImplementedError def relay( self, @@ -196,7 +197,7 @@ def relay( expanded dynamically when a new client joins. """ - pass + raise NotImplementedError def relay_and_wait( self, @@ -210,4 +211,112 @@ def relay_and_wait( abort_signal: Signal = None, ): """This is the blocking version of 'relay'.""" - pass + raise NotImplementedError + + def get_num_standing_tasks(self) -> int: + """Gets tasks that are currently standing. + + Returns: length of the list of standing tasks + + """ + raise NotImplementedError + + def cancel_task( + self, + task: Task, + completion_status: TaskCompletionStatus = TaskCompletionStatus.CANCELLED, + fl_ctx: Optional[FLContext] = None, + ): + """Cancels the specified task. + + If the task is standing, the task is cancelled immediately (and removed from job queue) and calls + the task_done CB (if specified); + + If the task is not standing, this method has no effect. + + Args: + task: the task to be cancelled + completion_status: the TaskCompletionStatus of the task + fl_ctx: the FL context + + """ + raise NotImplementedError + + def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): + """Cancels all standing tasks. 
+ + Args: + completion_status: the TaskCompletionStatus of the task + fl_ctx: the FL context + """ + raise NotImplementedError + + def check_tasks(self): + """Checks if tasks should be exited.""" + raise NotImplementedError + + def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: + """Called by the Engine when a task request is received from a client. + + Args: + client: the Client that the task request is from + fl_ctx: the FLContext + + Returns: task name, task id, and task data + + """ + raise NotImplementedError + + def handle_exception(self, task_id: str, fl_ctx: FLContext): + """Called after process_task_request returns, but exception occurs before task is sent out.""" + raise NotImplementedError + + def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext): + """Called by the Engine to process the submitted result from a client. + + Args: + client: the Client that the submitted result is from + task_name: the name of the task + task_id: the id of the task + result: the Shareable result from the Client + fl_ctx: the FLContext + + """ + raise NotImplementedError + + def process_task_check(self, task_id: str, fl_ctx: FLContext): + """Called by the Engine to check whether a specified task still exists. + Args: + task_id: the id of the task + fl_ctx: the FLContext + Returns: the ClientTask object if exists; None otherwise + """ + raise NotImplementedError + + def handle_dead_job(self, client_name: str, fl_ctx: FLContext): + """Called by the Engine to handle the case that the job on the client is dead. + + Args: + client_name: name of the client on which the job is dead + fl_ctx: the FLContext + + """ + raise NotImplementedError + + def initialize_run(self, fl_ctx: FLContext): + """Called when a new RUN is about to start. + + Args: + fl_ctx: FL context. It must contain 'job_id' that is to be initialized + + """ + raise NotImplementedError + + def finalize_run(self, fl_ctx: FLContext): + """Called when a new RUN is finished. 
+ + Args: + fl_ctx: the FL context + + """ + raise NotImplementedError diff --git a/nvflare/app_common/ccwf/server_ctl.py b/nvflare/app_common/ccwf/server_ctl.py index 55f31643cc..80933969b5 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -344,7 +344,7 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): self.log_info(fl_ctx, f"Workflow {self.workflow_id} done!") def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.BEFORE_PROCESS_TASK: + if event_type == EventType.BEFORE_PROCESS_TASK_REQUEST: self._update_client_status(fl_ctx) def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool: diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py index c0ffac7c94..c735e82d44 100644 --- a/nvflare/private/fed/server/server_json_config.py +++ b/nvflare/private/fed/server/server_json_config.py @@ -17,6 +17,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import SystemConfigs, SystemVarName from nvflare.apis.impl.controller import Controller +from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -31,14 +32,17 @@ class WorkFlow: def __init__(self, id, controller: Controller): - """Workflow is a responder with ID. + """Workflow is a controller with ID. + + Setting communicator to WFCommServer for server-side workflow. Args: id: identification - responder (Responder): A responder + controller (Controller): A controller """ self.id = id self.controller = controller + self.controller.set_communicator(WFCommServer()) class ServerJsonConfigurator(FedJsonConfigurator): diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 7f5a028697..6b8886be47 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -20,7 +20,6 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.apis.signal import Signal @@ -127,7 +126,7 @@ def _execute_run(self): fl_ctx.set_prop(FLContextKey.WORKFLOW, wf.id, sticky=True) - wf.controller.set_communicator(WFCommServer(), fl_ctx) + wf.controller.initialize(fl_ctx) wf.controller.communicator.initialize_run(fl_ctx) wf.controller.start_controller(fl_ctx) @@ -335,10 +334,13 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005): self.log_debug(fl_ctx, "no current workflow - asked client to try again later") return "", "", None - self.fire_event(EventType.BEFORE_PROCESS_TASK, fl_ctx) + self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_TASK_REQUEST") + self.fire_event(EventType.BEFORE_PROCESS_TASK_REQUEST, fl_ctx) task_name, task_id, task_data = self.current_wf.controller.communicator.process_task_request( client, fl_ctx ) + self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_TASK_REQUEST") + self.fire_event(EventType.AFTER_PROCESS_TASK_REQUEST, fl_ctx) if task_name and task_name != 
SpecialTaskName.TRY_AGAIN: if task_data: diff --git a/tests/unit_test/apis/impl/controller_test.py b/tests/unit_test/apis/impl/controller_test.py index 439199e267..24e464f9bd 100644 --- a/tests/unit_test/apis/impl/controller_test.py +++ b/tests/unit_test/apis/impl/controller_test.py @@ -125,7 +125,8 @@ def _setup_system(num_clients=1): controller = DummyController() fl_ctx = mock_server_engine.new_context() communicator = WFCommServer() - controller.set_communicator(communicator, fl_ctx) + controller.set_communicator(communicator) + controller.initialize(fl_ctx) controller.communicator.initialize_run(fl_ctx=fl_ctx) return controller, mock_server_engine, fl_ctx, clients_list @@ -233,7 +234,7 @@ def test_check_task_remove_cancelled_tasks(self, method, num_of_start_tasks, num for i in range(num_of_cancel_tasks): controller.cancel_task(task=all_tasks[i], fl_ctx=fl_ctx) assert all_tasks[i].completion_status == TaskCompletionStatus.CANCELLED - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == (num_of_start_tasks - num_of_cancel_tasks) controller.cancel_all_tasks() for thread in all_threads: @@ -263,7 +264,7 @@ def test_client_request_after_cancel_task(self, method, num_client_requests): # check if task_id is empty means this task is not assigned assert task_id == "" assert data is None - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED launch_thread.join() @@ -645,7 +646,7 @@ def result_received_cb(client_task: ClientTask, **kwargs): ) assert task.last_client_task_map[client_name].result["_test_data"] == client_name - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -690,12 +691,12 @@ def test_task_done_cb(self, method, num_clients, task_name, input_data, cb, expe ) if task_complete == "timeout": time.sleep(timeout) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert task.completion_status == TaskCompletionStatus.TIMEOUT elif task_complete == "cancel": controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert task.props[task_name] == expected assert controller.get_num_standing_tasks() == 0 launch_thread.join() @@ -867,7 +868,7 @@ def result_received_cb(client_task: ClientTask, fl_ctx: FLContext): controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=ctx, result=data ) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 1 new_task_name_out = "" while new_task_name_out == "": @@ -914,14 +915,14 @@ def result_received_cb(client_task: ClientTask, fl_ctx: FLContext): launch_thread.start() clients_pull_and_submit_result(controller=controller, ctx=ctx, clients=clients, task_name="__test_task") - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == num_of_clients for i in range(num_of_clients): clients_pull_and_submit_result( controller=controller, ctx=ctx, clients=clients, task_name=f"__new_test_task_{clients[i].name}" ) - controller.communicator._check_tasks() + 
controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == num_of_clients - (i + 1) launch_thread.join() @@ -1042,7 +1043,7 @@ def test_cancel_task(self, method): assert controller.get_num_standing_tasks() == 1 controller.cancel_task(task=task) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED launch_thread.join() @@ -1079,7 +1080,7 @@ def test_cancel_all_tasks(self, method): assert controller.get_num_standing_tasks() == 2 controller.cancel_all_tasks() - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED assert task1.completion_status == TaskCompletionStatus.CANCELLED @@ -1135,7 +1136,7 @@ def test_client_receive_only_one_task(self, method, num_of_clients): ) assert task.last_client_task_map[client.name].result == result - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -1208,13 +1209,13 @@ def test_task_only_exit_when_min_responses_received(self, method, min_responses) for client, client_task_id in zip(clients, client_task_ids): result = Shareable() - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 1 controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1257,7 +1258,7 @@ def test_task_exit_quickly_when_all_responses_received(self, method, min_respons client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1290,18 +1291,18 @@ def test_min_resp_is_zero_task_only_exit_when_all_client_task_done(self, method, task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == "__test_task" - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 1 for client, client_task_id in zip(clients, client_task_ids): - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 1 result = Shareable() controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1814,7 +1815,7 @@ def test_process_request_and_submission_with_task_assignment_timeout( # client side running some logic to generate result if expected_client_to_get_task: - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert 
controller.get_num_standing_tasks() == 1 result = Shareable() controller.communicator.process_submission( @@ -1826,7 +1827,7 @@ def test_process_request_and_submission_with_task_assignment_timeout( ) launch_thread.join() - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 self.teardown_system(controller, fl_ctx) @@ -1894,7 +1895,7 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, # need to make sure the header is set assert result.get_header(ReservedHeaderKey.REPLY_IS_LATE) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 1 self.teardown_system(controller, fl_ctx) @@ -2029,7 +2030,7 @@ def test_process_task_request_client_not_in_target_get_nothing(self, method, sen assert _client_task_id == "" controller.cancel_task(task) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -2072,7 +2073,7 @@ def test_process_task_request_expected_client_get_task_and_unexpected_clients_ge controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -2169,14 +2170,14 @@ def test_send_only_one_task_and_exit_when_client_task_done(self, method, num_of_ # once a client gets a task, other clients should not get task _assert_other_clients_get_no_task(controller=controller, fl_ctx=fl_ctx, client_idx=0, clients=clients) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 1 controller.communicator.process_submission( client=clients[0], task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=data ) - controller.communicator._check_tasks() + controller.communicator.check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join()
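For reference, the wiring that replaces the old Responder-based flow can be summarized outside the diff as a minimal sketch. It mirrors the updated test setup and the server_runner/server_json_config changes above; the attach_server_communicator helper name and the externally supplied fl_ctx are illustrative assumptions, not part of the patch itself.

from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.impl.wf_comm_server import WFCommServer


def attach_server_communicator(controller: Controller, fl_ctx: FLContext) -> None:
    # Pair a workflow Controller with the server-side communicator before the RUN starts.
    controller.set_communicator(WFCommServer())  # communication layer, now separate from control logic
    controller.initialize(fl_ctx)  # resolves and stores the engine from the FL context
    controller.communicator.initialize_run(fl_ctx=fl_ctx)  # sets up task bookkeeping for the RUN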