Skip to content

Commit

Permalink
remove responder, address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed Mar 8, 2024
1 parent 16b3119 commit 672d74b
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 167 deletions.
3 changes: 2 additions & 1 deletion nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class EventType(object):

BEFORE_PULL_TASK = "_before_pull_task"
AFTER_PULL_TASK = "_after_pull_task"
BEFORE_PROCESS_TASK = "_before_process_task_request"
BEFORE_PROCESS_TASK_REQUEST = "_before_process_task_request"
AFTER_PROCESS_TASK_REQUEST = "_after_process_task_request"
BEFORE_PROCESS_SUBMISSION = "_before_process_submission"
AFTER_PROCESS_SUBMISSION = "_after_process_submission"

Expand Down
31 changes: 12 additions & 19 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,22 @@ def __init__(self, task_check_period=0.2):
self._task_check_period = task_check_period
self.communicator = None

def set_communicator(self, communicator: WFCommSpec, fl_ctx: FLContext):
if not isinstance(communicator, WFCommSpec):
raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}")

self.communicator = communicator
self.communicator.controller = self
self.communicator.task_check_period = self._task_check_period
def initialize(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
if not engine:
self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx)
return

self._engine = engine

def set_communicator(self, communicator: WFCommSpec):
if not isinstance(communicator, WFCommSpec):
raise TypeError(f"communicator must be an instance of WFCommSpec, but got {type(communicator)}")

self.communicator = communicator
self.communicator.controller = self
self.communicator.task_check_period = self._task_check_period

def broadcast(
self,
task: Task,
Expand Down Expand Up @@ -133,21 +135,12 @@ def relay_and_wait(
)

def get_num_standing_tasks(self) -> int:
try:
return self.communicator.get_num_standing_tasks()
except:
raise NotImplementedError(f"{self.communicator} does not support this function")
return self.communicator.get_num_standing_tasks()

def cancel_task(
self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None
):
try:
self.communicator.cancel_task(task, completion_status, fl_ctx)
except:
raise NotImplementedError(f"{self.communicator} does not support this function")
self.communicator.cancel_task(task, completion_status, fl_ctx)

def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
try:
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
except:
raise NotImplementedError(f"{self.communicator} does not support this function")
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
12 changes: 6 additions & 6 deletions nvflare/apis/impl/wf_comm_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,10 +18,10 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import job_from_meta
from nvflare.apis.responder import Responder
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy
from nvflare.apis.signal import Signal
from nvflare.apis.wf_comm_spec import WFCommSpec
Expand Down Expand Up @@ -78,7 +78,7 @@ def _get_client_task(target, task: Task):
return None


class WFCommServer(Responder, WFCommSpec):
class WFCommServer(FLComponent, WFCommSpec):
def __init__(self, task_check_period=0.2):
"""Manage life cycles of tasks and their destinations.
Expand All @@ -96,7 +96,7 @@ def __init__(self, task_check_period=0.2):
self._task_check_period = task_check_period
self._dead_client_reports = {} # clients that reported the job is dead on it: name => report time
self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads
# make sure _check_tasks, process_task_request, process_submission does not interfere with each other
# make sure check_tasks, process_task_request, process_submission does not interfere with each other
self._controller_lock = Lock()

def initialize_run(self, fl_ctx: FLContext):
Expand Down Expand Up @@ -822,14 +822,14 @@ def _monitor_tasks(self):
while not self._all_done:
should_abort_job = self._job_policy_violated()
if not should_abort_job:
self._check_tasks()
self.check_tasks()
else:
with self._engine.new_context() as fl_ctx:
self.system_panic("Aborting job due to deployment policy violation", fl_ctx)
return
time.sleep(self._task_check_period)

def _check_tasks(self):
def check_tasks(self):
with self._controller_lock:
self._do_check_tasks()

Expand Down
101 changes: 0 additions & 101 deletions nvflare/apis/responder.py

This file was deleted.

Loading

0 comments on commit 672d74b

Please sign in to comment.