Skip to content

Commit

Permalink
fix unit tests, clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed Mar 8, 2024
1 parent 047bf77 commit fc753bc
Show file tree
Hide file tree
Showing 19 changed files with 144 additions and 117 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class EventType(object):

BEFORE_PULL_TASK = "_before_pull_task"
AFTER_PULL_TASK = "_after_pull_task"
BEFORE_PROCESS_TASK = "_before_process_task_request"
BEFORE_PROCESS_SUBMISSION = "_before_process_submission"
AFTER_PROCESS_SUBMISSION = "_after_process_submission"

Expand Down
36 changes: 22 additions & 14 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,33 @@
from typing import List, Optional, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, ControllerSpec, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.controller_spec import ControllerSpec, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.wf_comm_server import WFCommServer
from nvflare.apis.signal import Signal
from nvflare.apis.wf_comm_spec import WFCommSpec


class Controller(FLComponent, ControllerSpec, ABC):
def __init__(self, task_check_period=0.2):
"""Manage life cycles of tasks and their destinations.
"""Controller logic for tasks and their destinations.
Must set_communicator() to access communication related function implementations.
Args:
task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2.
task_check_period (float, optional): interval for checking status of tasks. Applicable for WFCommServer. Defaults to 0.2.
"""
super().__init__()
self._task_check_period = task_check_period
self.communicator = None

def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext):
communicator.task_check_period = self._task_check_period
if not isinstance(communicator, WFCommSpec):
raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}")

self.communicator = communicator
self.communicator.controller = self
self.communicator.task_check_period = self._task_check_period
engine = fl_ctx.get_engine()
if not engine:
self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx)
Expand Down Expand Up @@ -128,18 +133,21 @@ def relay_and_wait(
)

def get_num_standing_tasks(self) -> int:
if not isinstance(self.communicator, WFCommServer):
raise NotImplementedError
return self.communicator.get_num_standing_tasks()
try:
return self.communicator.get_num_standing_tasks()
except:
raise NotImplementedError(f"{self.communicator} does not support this function")

def cancel_task(
self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None
):
if not isinstance(self.communicator, WFCommServer):
raise NotImplementedError
self.communicator.cancel_task(task, completion_status, fl_ctx)
try:
self.communicator.cancel_task(task, completion_status, fl_ctx)
except:
raise NotImplementedError(f"{self.communicator} does not support this function")

def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
if not isinstance(self.communicator, WFCommServer):
raise NotImplementedError
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
try:
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
except:
raise NotImplementedError(f"{self.communicator} does not support this function")
5 changes: 4 additions & 1 deletion nvflare/apis/impl/wf_comm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, task_check_period=0.2):
task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.2.
"""
super().__init__()
self.controller = None
self._engine = None
self._tasks = [] # list of standing tasks
self._client_task_map = {} # client_task_id => client_task
Expand Down Expand Up @@ -343,6 +344,8 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

self.controller.handle_dead_job(client_name, fl_ctx)

def process_task_check(self, task_id: str, fl_ctx: FLContext):
with self._task_lock:
# task_id is the uuid associated with the client_task
Expand Down Expand Up @@ -397,7 +400,7 @@ def _do_process_submission(
if client_task is None:
# cannot find a standing task for the submission
self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id))
self.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx)
self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx)
return

task = client_task.task
Expand Down
1 change: 0 additions & 1 deletion nvflare/apis/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from abc import ABC, abstractmethod
from typing import Tuple


from .client import Client
from .fl_component import FLComponent
from .fl_context import FLContext
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/ccwf/cse_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

import os

from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Task
from nvflare.apis.shareable import ReturnCode, Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.workspace import Workspace
Expand Down
10 changes: 6 additions & 4 deletions nvflare/app_common/ccwf/server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from datetime import datetime

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import ReturnCode, Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
Expand Down Expand Up @@ -341,9 +343,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):

self.log_info(fl_ctx, f"Workflow {self.workflow_id} done!")

def process_task_request(self, client: Client, fl_ctx: FLContext):
self._update_client_status(fl_ctx)
return super().process_task_request(client, fl_ctx)
def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.BEFORE_PROCESS_TASK:
self._update_client_status(fl_ctx)

def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool:
return True
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/workflows/broadcast_and_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from typing import Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.response_processor import ResponseProcessor
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/workflows/broadcast_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from typing import Dict, List, Optional, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import DXO, from_shareable
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.workflows.error_handling_controller import ErrorHandlingController
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
import random

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
Expand Down
4 changes: 2 additions & 2 deletions nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from typing import List, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey
from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
Expand Down
4 changes: 2 additions & 2 deletions nvflare/app_common/workflows/scatter_and_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from typing import Any

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey
from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.aggregator import Aggregator
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/workflows/splitnn_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/workflows/statistics_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from typing import Callable, Dict, List, Optional

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, StatisticConfig
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/xgb/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import time

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.xgb.adaptors.adaptor import XGBServerAdaptor
Expand Down
1 change: 1 addition & 0 deletions nvflare/private/fed/server/server_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005):
self.log_debug(fl_ctx, "no current workflow - asked client to try again later")
return "", "", None

self.fire_event(EventType.BEFORE_PROCESS_TASK, fl_ctx)
task_name, task_id, task_data = self.current_wf.controller.communicator.process_task_request(
client, fl_ctx
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

import traceback

from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.aggregator import Aggregator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
Expand Down
Loading

0 comments on commit fc753bc

Please sign in to comment.