From fc753bc43187d4730f15cb9bbc5a30d011315475 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 7 Mar 2024 17:03:08 -0800 Subject: [PATCH] fix unit tests, clean up --- nvflare/apis/event_type.py | 1 + nvflare/apis/impl/controller.py | 36 ++-- nvflare/apis/impl/wf_comm_server.py | 5 +- nvflare/apis/responder.py | 1 - nvflare/app_common/ccwf/cse_server_ctl.py | 2 +- nvflare/app_common/ccwf/server_ctl.py | 10 +- .../workflows/broadcast_and_process.py | 3 +- .../workflows/broadcast_operator.py | 2 +- nvflare/app_common/workflows/cyclic_ctl.py | 3 +- .../app_common/workflows/model_controller.py | 4 +- .../workflows/scatter_and_gather.py | 4 +- .../app_common/workflows/splitnn_workflow.py | 3 +- .../workflows/statistics_controller.py | 3 +- nvflare/app_common/xgb/controller.py | 3 +- nvflare/private/fed/server/server_runner.py | 1 + .../autofedrl/autofedrl_scatter_and_gather.py | 2 +- .../app/custom/custom_controller.py | 3 +- tests/unit_test/apis/impl/controller_test.py | 173 +++++++++--------- .../workflow/mock_statistics_controller.py | 2 +- 19 files changed, 144 insertions(+), 117 deletions(-) diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 942da170c5..929c942e4c 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -37,6 +37,7 @@ class EventType(object): BEFORE_PULL_TASK = "_before_pull_task" AFTER_PULL_TASK = "_after_pull_task" + BEFORE_PROCESS_TASK = "_before_process_task_request" BEFORE_PROCESS_SUBMISSION = "_before_process_submission" AFTER_PROCESS_SUBMISSION = "_after_process_submission" diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 1297f35587..7f9f15bd5e 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -15,28 +15,33 @@ from typing import List, Optional, Union from nvflare.apis.client import Client -from nvflare.apis.controller_spec import ClientTask, ControllerSpec, SendOrder, Task, TaskCompletionStatus +from 
nvflare.apis.controller_spec import ControllerSpec, SendOrder, Task, TaskCompletionStatus from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.signal import Signal from nvflare.apis.wf_comm_spec import WFCommSpec class Controller(FLComponent, ControllerSpec, ABC): def __init__(self, task_check_period=0.2): - """Manage life cycles of tasks and their destinations. + """Controller logic for tasks and their destinations. + + Must set_communicator() to access communication related function implementations. Args: - task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2. + task_check_period (float, optional): interval for checking status of tasks. Applicable for WFCommServer. Defaults to 0.2. """ super().__init__() self._task_check_period = task_check_period self.communicator = None def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext): - communicator.task_check_period = self._task_check_period + if not isinstance(communicator, WFCommSpec): + raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}") + self.communicator = communicator + self.communicator.controller = self + self.communicator.task_check_period = self._task_check_period engine = fl_ctx.get_engine() if not engine: self.system_panic(f"Engine not found. 
{self.__class__.__name__} exiting.", fl_ctx)
@@ -128,18 +133,21 @@ def relay_and_wait(
         )
 
     def get_num_standing_tasks(self) -> int:
-        if not isinstance(self.communicator, WFCommServer):
-            raise NotImplementedError
-        return self.communicator.get_num_standing_tasks()
+        try:
+            return self.communicator.get_num_standing_tasks()
+        except AttributeError as e:  # communicator lacks the method; let real failures propagate
+            raise NotImplementedError(f"{self.communicator} does not support this function") from e
 
     def cancel_task(
         self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None
     ):
-        if not isinstance(self.communicator, WFCommServer):
-            raise NotImplementedError
-        self.communicator.cancel_task(task, completion_status, fl_ctx)
+        try:
+            self.communicator.cancel_task(task, completion_status, fl_ctx)
+        except AttributeError as e:  # communicator lacks the method; let real failures propagate
+            raise NotImplementedError(f"{self.communicator} does not support this function") from e
 
     def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
-        if not isinstance(self.communicator, WFCommServer):
-            raise NotImplementedError
-        self.communicator.cancel_all_tasks(completion_status, fl_ctx)
+        try:
+            self.communicator.cancel_all_tasks(completion_status, fl_ctx)
+        except AttributeError as e:  # communicator lacks the method; let real failures propagate
+            raise NotImplementedError(f"{self.communicator} does not support this function") from e
diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py
index b15f555c1a..a1ffc1f336 100644
--- a/nvflare/apis/impl/wf_comm_server.py
+++ b/nvflare/apis/impl/wf_comm_server.py
@@ -86,6 +86,7 @@ def __init__(self, task_check_period=0.2):
             task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2.
""" super().__init__() + self.controller = None self._engine = None self._tasks = [] # list of standing tasks self._client_task_map = {} # client_task_id => client_task @@ -343,6 +344,8 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if not self._dead_client_reports.get(client_name): self._dead_client_reports[client_name] = time.time() + self.controller.handle_dead_job(client_name, fl_ctx) + def process_task_check(self, task_id: str, fl_ctx: FLContext): with self._task_lock: # task_id is the uuid associated with the client_task @@ -397,7 +400,7 @@ def _do_process_submission( if client_task is None: # cannot find a standing task for the submission self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id)) - self.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) + self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx) return task = client_task.task diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py index 21dc51a257..9935c430ac 100644 --- a/nvflare/apis/responder.py +++ b/nvflare/apis/responder.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod from typing import Tuple - from .client import Client from .fl_component import FLComponent from .fl_context import FLContext diff --git a/nvflare/app_common/ccwf/cse_server_ctl.py b/nvflare/app_common/ccwf/cse_server_ctl.py index b55dc41800..f95e5fdc05 100644 --- a/nvflare/app_common/ccwf/cse_server_ctl.py +++ b/nvflare/app_common/ccwf/cse_server_ctl.py @@ -14,9 +14,9 @@ import os +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.dxo import from_shareable from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Task from nvflare.apis.shareable import ReturnCode, Shareable from nvflare.apis.signal import Signal from nvflare.apis.workspace import Workspace diff --git a/nvflare/app_common/ccwf/server_ctl.py 
b/nvflare/app_common/ccwf/server_ctl.py index 808acf9d33..55f31643cc 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -16,9 +16,11 @@ from datetime import datetime from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task +from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import ReturnCode, Shareable from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants @@ -341,9 +343,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): self.log_info(fl_ctx, f"Workflow {self.workflow_id} done!") - def process_task_request(self, client: Client, fl_ctx: FLContext): - self._update_client_status(fl_ctx) - return super().process_task_request(client, fl_ctx) + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.BEFORE_PROCESS_TASK: + self._update_client_status(fl_ctx) def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool: return True diff --git a/nvflare/app_common/workflows/broadcast_and_process.py b/nvflare/app_common/workflows/broadcast_and_process.py index fd2a8216aa..d015817c30 100644 --- a/nvflare/app_common/workflows/broadcast_and_process.py +++ b/nvflare/app_common/workflows/broadcast_and_process.py @@ -15,8 +15,9 @@ from typing import Union from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from 
nvflare.app_common.abstract.response_processor import ResponseProcessor diff --git a/nvflare/app_common/workflows/broadcast_operator.py b/nvflare/app_common/workflows/broadcast_operator.py index 8827e6209f..ce548cd95e 100644 --- a/nvflare/app_common/workflows/broadcast_operator.py +++ b/nvflare/app_common/workflows/broadcast_operator.py @@ -16,11 +16,11 @@ from typing import Dict, List, Optional, Union from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.dxo import DXO, from_shareable from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Task from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.workflows.error_handling_controller import ErrorHandlingController diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index 754e1b06b6..65f8cb9195 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -16,9 +16,10 @@ import random from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 0b65f07539..cd6b565241 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -16,10 +16,10 @@ from typing import List, Union from 
nvflare.apis.client import Client -from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey +from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.fl_model import FLModel, ParamsType diff --git a/nvflare/app_common/workflows/scatter_and_gather.py b/nvflare/app_common/workflows/scatter_and_gather.py index e663d81e5a..2255c6a160 100644 --- a/nvflare/app_common/workflows/scatter_and_gather.py +++ b/nvflare/app_common/workflows/scatter_and_gather.py @@ -15,10 +15,10 @@ from typing import Any from nvflare.apis.client import Client -from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey +from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.aggregator import Aggregator diff --git a/nvflare/app_common/workflows/splitnn_workflow.py b/nvflare/app_common/workflows/splitnn_workflow.py index ae3e8a17a8..b16df0e837 100644 --- a/nvflare/app_common/workflows/splitnn_workflow.py +++ b/nvflare/app_common/workflows/splitnn_workflow.py @@ -13,9 +13,10 @@ # limitations under the License. 
from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor diff --git a/nvflare/app_common/workflows/statistics_controller.py b/nvflare/app_common/workflows/statistics_controller.py index 1c83bba99d..bc59aefbd8 100644 --- a/nvflare/app_common/workflows/statistics_controller.py +++ b/nvflare/app_common/workflows/statistics_controller.py @@ -16,11 +16,12 @@ from typing import Callable, Dict, List, Optional from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.dxo import from_shareable from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, StatisticConfig diff --git a/nvflare/app_common/xgb/controller.py b/nvflare/app_common/xgb/controller.py index 83bd6f90a1..5e96f95aa2 100644 --- a/nvflare/app_common/xgb/controller.py +++ b/nvflare/app_common/xgb/controller.py @@ -15,8 +15,9 @@ import time from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import ReturnCode, Shareable, make_reply from 
nvflare.apis.signal import Signal from nvflare.app_common.xgb.adaptors.adaptor import XGBServerAdaptor diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 224d2f6304..7f5a028697 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -335,6 +335,7 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005): self.log_debug(fl_ctx, "no current workflow - asked client to try again later") return "", "", None + self.fire_event(EventType.BEFORE_PROCESS_TASK, fl_ctx) task_name, task_id, task_data = self.current_wf.controller.communicator.process_task_request( client, fl_ctx ) diff --git a/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py b/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py index 20a44f064b..6d2e3bd8a3 100644 --- a/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py +++ b/research/auto-fed-rl/src/autofedrl/autofedrl_scatter_and_gather.py @@ -14,10 +14,10 @@ import traceback +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Task from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.aggregator import Aggregator diff --git a/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py b/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py index dd36aad0f3..9a585eb81c 100755 --- a/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py +++ b/tests/integration_test/data/apps/tb_streaming/app/custom/custom_controller.py @@ -13,8 +13,9 @@ # limitations under the License. 
from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor diff --git a/tests/unit_test/apis/impl/controller_test.py b/tests/unit_test/apis/impl/controller_test.py index 2cfc656d44..439199e267 100644 --- a/tests/unit_test/apis/impl/controller_test.py +++ b/tests/unit_test/apis/impl/controller_test.py @@ -26,6 +26,7 @@ from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus from nvflare.apis.fl_context import FLContext, FLContextManager from nvflare.apis.impl.controller import Controller +from nvflare.apis.impl.wf_comm_server import WFCommServer from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.shareable import ReservedHeaderKey, Shareable from nvflare.apis.signal import Signal @@ -123,7 +124,9 @@ def _setup_system(num_clients=1): controller = DummyController() fl_ctx = mock_server_engine.new_context() - controller.initialize_run(fl_ctx=fl_ctx) + communicator = WFCommServer() + controller.set_communicator(communicator, fl_ctx) + controller.communicator.initialize_run(fl_ctx=fl_ctx) return controller, mock_server_engine, fl_ctx, clients_list @@ -139,7 +142,7 @@ def setup_system(num_of_clients=1): @staticmethod def teardown_system(controller, fl_ctx): - controller.finalize_run(fl_ctx=fl_ctx) + controller.communicator.finalize_run(fl_ctx=fl_ctx) class TestTaskManagement(TestController): @@ -230,7 +233,7 @@ def test_check_task_remove_cancelled_tasks(self, method, num_of_start_tasks, num for i in range(num_of_cancel_tasks): controller.cancel_task(task=all_tasks[i], fl_ctx=fl_ctx) assert all_tasks[i].completion_status == 
TaskCompletionStatus.CANCELLED - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == (num_of_start_tasks - num_of_cancel_tasks) controller.cancel_all_tasks() for thread in all_threads: @@ -256,11 +259,11 @@ def test_client_request_after_cancel_task(self, method, num_client_requests): get_ready(launch_thread) controller.cancel_task(task) for i in range(num_client_requests): - _, task_id, data = controller.process_task_request(client, fl_ctx) + _, task_id, data = controller.communicator.process_task_request(client, fl_ctx) # check if task_id is empty means this task is not assigned assert task_id == "" assert data is None - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED launch_thread.join() @@ -282,18 +285,18 @@ def test_client_submit_result_after_cancel_task(self, method): }, ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED time.sleep(1) - print(controller._tasks) + print(controller.communicator._tasks) # in here we make up client results: result = Shareable() result["result"] = "result" with pytest.raises(RuntimeError, match="Unknown task: __test_task from client __test_client0."): - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) @@ -531,7 +534,7 @@ def test_process_submission_invalid_input(self, method, kwargs, error, msg): controller, fl_ctx, clients = self.setup_system() with pytest.raises(error, match=msg): - controller.process_submission(**kwargs) + 
controller.communicator.process_submission(**kwargs) self.teardown_system(controller, fl_ctx) @@ -573,13 +576,13 @@ def clients_pull_and_submit_result(controller, ctx, clients, task_name): client_task_ids = [] num_of_clients = len(clients) for i in range(num_of_clients): - task_name_out, client_task_id, data = controller.process_task_request(clients[i], ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[i], ctx) assert task_name_out == task_name client_task_ids.append(client_task_id) for client, client_task_id in zip(clients, client_task_ids): data = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=client, task_name=task_name, task_id=client_task_id, fl_ctx=ctx, result=data ) @@ -606,7 +609,7 @@ def before_task_sent_cb(client_task: ClientTask, **kwargs): ) get_ready(launch_thread) - task_name_out, _, data = controller.process_task_request(client, fl_ctx) + task_name_out, _, data = controller.communicator.process_task_request(client, fl_ctx) assert data["_test_data"] == client_name controller.cancel_task(task) @@ -636,13 +639,13 @@ def result_received_cb(client_task: ClientTask, **kwargs): }, ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) - controller.process_submission( + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=data ) assert task.last_client_task_map[client_name].result["_test_data"] == client_name - controller._check_tasks() + controller.communicator._check_tasks() assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -670,7 +673,7 @@ def test_task_done_cb(self, method, num_clients, task_name, input_data, cb, expe client_task_ids = len(clients) * 
[None] for i, client in enumerate(clients): - task_name_out, client_task_ids[i], _ = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_ids[i], _ = controller.communicator.process_task_request(client, fl_ctx) if task_name_out == "": client_task_ids[i] = None @@ -682,17 +685,17 @@ def test_task_done_cb(self, method, num_clients, task_name, input_data, cb, expe for client, client_task_id in zip(clients, client_task_ids): if client_task_id is not None: if task_complete == "normal": - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) if task_complete == "timeout": time.sleep(timeout) - controller._check_tasks() + controller.communicator._check_tasks() assert task.completion_status == TaskCompletionStatus.TIMEOUT elif task_complete == "cancel": controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED - controller._check_tasks() + controller.communicator._check_tasks() assert task.props[task_name] == expected assert controller.get_num_standing_tasks() == 0 launch_thread.join() @@ -718,7 +721,7 @@ def before_task_sent_cb(client_task: ClientTask, **kwargs): ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -748,16 +751,16 @@ def result_received_cb(client_task: ClientTask, **kwargs): ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client1, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client1, fl_ctx) result = Shareable() result["__result"] = "__test_result" - controller.process_submission( + controller.communicator.process_submission( client=client1, task_name="__test_task", 
task_id=client_task_id, fl_ctx=fl_ctx, result=result ) assert task.last_client_task_map["__test_client0"].result == result - task_name_out, client_task_id, data = controller.process_task_request(client2, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client2, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -803,12 +806,12 @@ def before_task_sent_cb(client_task: ClientTask, fl_ctx: FLContext): task_name_out = "" while task_name_out == "": - task_name_out, _, _ = controller.process_task_request(client, ctx) + task_name_out, _, _ = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert task_name_out == "__test_task" new_task_name_out = "" while new_task_name_out == "": - new_task_name_out, _, _ = controller.process_task_request(client, ctx) + new_task_name_out, _, _ = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert new_task_name_out == "__new_test_task" @@ -857,18 +860,18 @@ def result_received_cb(client_task: ClientTask, fl_ctx: FLContext): client_task_id = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert task_name_out == "__test_task" - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=ctx, result=data ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 new_task_name_out = "" while new_task_name_out == "": - new_task_name_out, _, _ = controller.process_task_request(client, ctx) + new_task_name_out, _, _ = controller.communicator.process_task_request(client, ctx) time.sleep(0.1) assert new_task_name_out == "__new_test_task" launch_thread.join() @@ -911,14 +914,14 @@ def 
result_received_cb(client_task: ClientTask, fl_ctx: FLContext): launch_thread.start() clients_pull_and_submit_result(controller=controller, ctx=ctx, clients=clients, task_name="__test_task") - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == num_of_clients for i in range(num_of_clients): clients_pull_and_submit_result( controller=controller, ctx=ctx, clients=clients, task_name=f"__new_test_task_{clients[i].name}" ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == num_of_clients - (i + 1) launch_thread.join() @@ -931,7 +934,7 @@ def test_process_submission_invalid_task(self, task_name, client_name): controller, fl_ctx, clients = self.setup_system() client = clients[0] with pytest.raises(RuntimeError, match=f"Unknown task: {task_name} from client {client_name}."): - controller.process_submission( + controller.communicator.process_submission( client=client, task_name=task_name, task_id=str(uuid.uuid4()), fl_ctx=FLContext(), result=Shareable() ) self.teardown_system(controller, fl_ctx) @@ -957,7 +960,7 @@ def test_process_task_request_client_request_multiple_times(self, method, num_cl get_ready(launch_thread) for i in range(num_client_requests): - task_name_out, _, data = controller.process_task_request(client, fl_ctx) + task_name_out, _, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map["__test_client0"].task_send_count == num_client_requests @@ -982,12 +985,12 @@ def test_process_submission(self, method): ) get_ready(launch_thread) - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) # in here we make up client results: result = Shareable() result["result"] = "result" - 
controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=result ) assert task.last_client_task_map["__test_client0"].result == result @@ -1039,7 +1042,7 @@ def test_cancel_task(self, method): assert controller.get_num_standing_tasks() == 1 controller.cancel_task(task=task) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED launch_thread.join() @@ -1076,7 +1079,7 @@ def test_cancel_all_tasks(self, method): assert controller.get_num_standing_tasks() == 2 controller.cancel_all_tasks() - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.CANCELLED assert task1.completion_status == TaskCompletionStatus.CANCELLED @@ -1111,19 +1114,19 @@ def test_client_receive_only_one_task(self, method, num_of_clients): client_task_id = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[client.name].task_send_count == 1 assert controller.get_num_standing_tasks() == 1 - _, next_client_task_id, _ = controller.process_task_request(client, fl_ctx) + _, next_client_task_id, _ = controller.communicator.process_task_request(client, fl_ctx) assert next_client_task_id == client_task_id assert task.last_client_task_map[client.name].task_send_count == 2 result = Shareable() result["result"] = "result" - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, @@ -1132,7 +1135,7 @@ 
def test_client_receive_only_one_task(self, method, num_of_clients): ) assert task.last_client_task_map[client.name].result == result - controller._check_tasks() + controller.communicator._check_tasks() assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -1160,7 +1163,7 @@ def test_only_client_in_target_will_get_task(self, method, num_of_clients): task_name_out = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1168,7 +1171,7 @@ def test_only_client_in_target_will_get_task(self, method, num_of_clients): assert controller.get_num_standing_tasks() == 1 for client in clients[1:]: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -1199,19 +1202,19 @@ def test_task_only_exit_when_min_responses_received(self, method, min_responses) client_task_ids = [] for client in clients: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == "__test_task" for client, client_task_id in zip(clients, client_task_ids): result = Shareable() - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller._check_tasks() + 
controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1244,17 +1247,17 @@ def test_task_exit_quickly_when_all_responses_received(self, method, min_respons client_task_ids = [] for client in clients: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == "__test_task" for client, client_task_id in zip(clients, client_task_ids): result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1284,21 +1287,21 @@ def test_min_resp_is_zero_task_only_exit_when_all_client_task_done(self, method, client_task_ids = [] for client in clients: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) client_task_ids.append(client_task_id) assert task_name_out == "__test_task" - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 for client, client_task_id in zip(clients, client_task_ids): - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=client, task_name="__test_task", task_id=client_task_id, result=result, fl_ctx=fl_ctx ) - controller._check_tasks() + 
controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() @@ -1604,7 +1607,7 @@ def test_only_client_in_target_will_get_task(self, method, send_order): task_name_out = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1612,7 +1615,7 @@ def test_only_client_in_target_will_get_task(self, method, send_order): assert controller.get_num_standing_tasks() == 1 for client in clients[1:]: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -1650,7 +1653,7 @@ def test_task_assignment_timeout_sequential_order_only_client_in_target_will_get time.sleep(task_assignment_timeout + 1) for client in clients[1:]: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -1695,7 +1698,7 @@ def test_process_task_request( assert controller.get_num_standing_tasks() == 1 time.sleep(time_before_first_request) - task_name, task_id, data = controller.process_task_request(client=request_client, fl_ctx=fl_ctx) + task_name, task_id, data = controller.communicator.process_task_request(client=request_client, fl_ctx=fl_ctx) client_get_a_task = True if task_name == "__test_task" else False assert client_get_a_task == expected_to_get_task @@ -1729,7 +1732,7 @@ def test_sequential_sequence(self, method, targets): client_tasks_and_results = {} for c in 
targets: - task_name, task_id, data = controller.process_task_request(client=c, fl_ctx=fl_ctx) + task_name, task_id, data = controller.communicator.process_task_request(client=c, fl_ctx=fl_ctx) if task_name != "": client_result = Shareable() client_result["result"] = f"{c.name}" @@ -1740,7 +1743,7 @@ def test_sequential_sequence(self, method, targets): for task_id in client_tasks_and_results.keys(): c, task_name, client_result = client_tasks_and_results[task_id] task.data["result"] += client_result["result"] - controller.process_submission( + controller.communicator.process_submission( client=c, task_name=task_name, task_id=task_id, result=client_result, fl_ctx=fl_ctx ) assert task.last_client_task_map[c.name].result == client_result @@ -1797,22 +1800,24 @@ def test_process_request_and_submission_with_task_assignment_timeout( data = None task_name_out = "" while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request( + client, fl_ctx + ) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[client.name].task_send_count == 1 else: - _task_name_out, _client_task_id, _ = controller.process_task_request(client, fl_ctx) + _task_name_out, _client_task_id, _ = controller.communicator.process_task_request(client, fl_ctx) assert _task_name_out == "" assert _client_task_id == "" # client side running some logic to generate result if expected_client_to_get_task: - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=expected_client_to_get_task, task_name=task_name_out, task_id=client_task_id, @@ -1821,7 +1826,7 @@ def test_process_request_and_submission_with_task_assignment_timeout( ) launch_thread.join() - 
controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 self.teardown_system(controller, fl_ctx) @@ -1856,7 +1861,7 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, task_name_out = "" old_client_task_id = "" while task_name_out == "": - task_name_out, old_client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, old_client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1865,21 +1870,21 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, time.sleep(task_result_timeout + 1) # same client ask should get the same task - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) assert client_task_id == old_client_task_id assert task.last_client_task_map[clients[0].name].task_send_count == 2 time.sleep(task_result_timeout + 1) # second client ask should get a task since task_result_timeout passed - task_name_out, client_task_id_1, data = controller.process_task_request(clients[1], fl_ctx) + task_name_out, client_task_id_1, data = controller.communicator.process_task_request(clients[1], fl_ctx) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[clients[1].name].task_send_count == 1 # then we get back first client's result result = Shareable() - controller.process_submission( + controller.communicator.process_submission( client=clients[0], task_name=task_name_out, task_id=client_task_id, @@ -1889,7 +1894,7 @@ def test_process_submission_after_first_client_task_result_timeout(self, method, # need to make sure the header is set assert result.get_header(ReservedHeaderKey.REPLY_IS_LATE) - controller._check_tasks() + 
controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 self.teardown_system(controller, fl_ctx) @@ -1925,7 +1930,7 @@ def test_process_submission_all_client_task_result_timeout(self, method, send_or task_name_out = "" while task_name_out == "": - task_name_out, old_client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, old_client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -1951,7 +1956,7 @@ def _assert_other_clients_get_no_task(controller, fl_ctx, client_idx: int, clien for i, client in enumerate(clients): if i == client_idx: continue - _task_name_out, _client_task_id, data = controller.process_task_request(client, fl_ctx) + _task_name_out, _client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert _task_name_out == "" assert _client_task_id == "" @@ -2019,12 +2024,12 @@ def test_process_task_request_client_not_in_target_get_nothing(self, method, sen assert controller.get_num_standing_tasks() == 1 # this client not in target so should get nothing - _task_name_out, _client_task_id, data = controller.process_task_request(client, fl_ctx) + _task_name_out, _client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert _task_name_out == "" assert _client_task_id == "" controller.cancel_task(task) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -2054,7 +2059,9 @@ def test_process_task_request_expected_client_get_task_and_unexpected_clients_ge task_name_out = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(targets[client_idx], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request( + 
targets[client_idx], fl_ctx + ) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -2065,7 +2072,7 @@ def test_process_task_request_expected_client_get_task_and_unexpected_clients_ge controller.cancel_task(task) assert task.completion_status == TaskCompletionStatus.CANCELLED - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 launch_thread.join() self.teardown_system(controller, fl_ctx) @@ -2113,13 +2120,13 @@ def test_process_task_request_with_task_assignment_timeout_expected_client_get_t if client.name == expected_client_to_get_task: task_name_out = "" while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data assert task.last_client_task_map[client.name].task_send_count == 1 else: - task_name_out, client_task_id, data = controller.process_task_request(client, fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(client, fl_ctx) assert task_name_out == "" assert client_task_id == "" @@ -2153,7 +2160,7 @@ def test_send_only_one_task_and_exit_when_client_task_done(self, method, num_of_ client_task_id = "" data = None while task_name_out == "": - task_name_out, client_task_id, data = controller.process_task_request(clients[0], fl_ctx) + task_name_out, client_task_id, data = controller.communicator.process_task_request(clients[0], fl_ctx) time.sleep(0.1) assert task_name_out == "__test_task" assert data == input_data @@ -2162,14 +2169,14 @@ def test_send_only_one_task_and_exit_when_client_task_done(self, method, num_of_ # once a client gets a task, other clients should not get task _assert_other_clients_get_no_task(controller=controller, fl_ctx=fl_ctx, client_idx=0, clients=clients) - 
controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 1 - controller.process_submission( + controller.communicator.process_submission( client=clients[0], task_name="__test_task", task_id=client_task_id, fl_ctx=fl_ctx, result=data ) - controller._check_tasks() + controller.communicator._check_tasks() assert controller.get_num_standing_tasks() == 0 assert task.completion_status == TaskCompletionStatus.OK launch_thread.join() diff --git a/tests/unit_test/app_common/workflow/mock_statistics_controller.py b/tests/unit_test/app_common/workflow/mock_statistics_controller.py index f6bc0e8eb9..9f6d850035 100644 --- a/tests/unit_test/app_common/workflow/mock_statistics_controller.py +++ b/tests/unit_test/app_common/workflow/mock_statistics_controller.py @@ -15,8 +15,8 @@ from typing import Dict from nvflare.apis.client import Client +from nvflare.apis.controller_spec import ClientTask from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import ClientTask from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal from nvflare.app_common.workflows.statistics_controller import StatisticsController