diff --git a/nvflare/private/fed/app/simulator/simulator.py b/nvflare/private/fed/app/simulator/simulator.py index 246cb61b68..50a5a56319 100644 --- a/nvflare/private/fed/app/simulator/simulator.py +++ b/nvflare/private/fed/app/simulator/simulator.py @@ -30,6 +30,12 @@ def define_simulator_parser(simulator_parser): simulator_parser.add_argument("-t", "--threads", type=int, help="number of parallel running clients") simulator_parser.add_argument("-gpu", "--gpu", type=str, help="list of GPU Device Ids, comma separated") simulator_parser.add_argument("-m", "--max_clients", type=int, default=100, help="max number of clients") + simulator_parser.add_argument( + "--end_run_for_all", + default=False, + action="store_true", + help="flag to indicate if running END_RUN event for all clients", + ) def run_simulator(simulator_args): @@ -41,6 +47,7 @@ def run_simulator(simulator_args): threads=simulator_args.threads, gpu=simulator_args.gpu, max_clients=simulator_args.max_clients, + end_run_for_all=simulator_args.end_run_for_all, ) run_status = simulator.run() diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index e9565d881b..55ed138376 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -69,7 +69,15 @@ class SimulatorRunner(FLComponent): def __init__( - self, job_folder: str, workspace: str, clients=None, n_clients=None, threads=None, gpu=None, max_clients=100 + self, + job_folder: str, + workspace: str, + clients=None, + n_clients=None, + threads=None, + gpu=None, + max_clients=100, + end_run_for_all=False, ): super().__init__() @@ -80,6 +88,7 @@ def __init__( self.threads = threads self.gpu = gpu self.max_clients = max_clients + self.end_run_for_all = end_run_for_all self.ask_to_stop = False @@ -142,6 +151,7 @@ def setup(self): self.args.env = os.path.join("config", AppFolderConstants.CONFIG_ENV) cwd = os.getcwd() 
self.args.job_folder = os.path.join(cwd, self.args.job_folder) + self.args.end_run_for_all = self.end_run_for_all if not os.path.exists(self.args.workspace): os.makedirs(self.args.workspace) @@ -523,7 +533,7 @@ def __init__(self, args, clients: [], client_config, deploy_args, build_ctx): self.kv_list = parse_vars(args.set) self.logging_config = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG) - self.end_run_clients = [] + self.clients_finished_end_run = [] def run(self, gpu): try: @@ -533,17 +543,14 @@ def run(self, gpu): lock = threading.Lock() timeout = self.kv_list.get("simulator_worker_timeout", 60.0) for i in range(self.args.threads): - executor.submit(lambda p: self.run_client_thread(*p), [self.args.threads, gpu, lock, i, timeout]) + executor.submit( + lambda p: self.run_client_thread(*p), + [self.args.threads, gpu, lock, self.args.end_run_for_all, timeout], + ) # wait for the server and client running thread to finish. executor.shutdown() - for client in self.federated_clients: - if client.client_name not in self.end_run_clients: - self.do_one_task( - client, self.args.threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN - ) - except Exception as e: self.logger.error(f"SimulatorClientRunner run error: {secure_format_exception(e)}") finally: @@ -562,7 +569,7 @@ def _shutdown_client(self, client): # Ignore the exception for the simulator client shutdown self.logger.warn(f"Exception happened to client{client.name} during shutdown ") - def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60): + def run_client_thread(self, num_of_threads, gpu, lock, end_run_for_all, timeout=60): stop_run = False interval = 1 client_to_run = None # indicates the next client to run @@ -582,12 +589,43 @@ def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60): ) if end_run_client: with lock: - self.end_run_clients.append(end_run_client) + self.clients_finished_end_run.append(end_run_client) 
client.simulate_running = False + + if end_run_for_all: + self._end_run_clients(gpu, lock, num_of_threads, timeout) except Exception as e: self.logger.error(f"run_client_thread error: {secure_format_exception(e)}") + def _end_run_clients(self, gpu, lock, num_of_threads, timeout): + """After the WF reaches the END_RUN, each running thread will try to pick up one of the remaining clients + that has not run the END_RUN yet, then execute the END_RUN handler, until all the clients have done so. + These client END_RUN event handlers only execute when "end_run_for_all" has been set. + + Multiple client running threads will try to pick up clients from the same client pool. + + """ + # Each thread keeps picking up NOT-DONE clients until all clients have run the END_RUN event. + while len(self.clients_finished_end_run) != len(self.federated_clients): + with lock: + end_run_client = self._pick_next_client() + if end_run_client: + self.do_one_task( + end_run_client, num_of_threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN + ) + with lock: + end_run_client.simulate_running = False + + def _pick_next_client(self): + for client in self.federated_clients: + # Ensure the client has not run the END_RUN event + if client.client_name not in self.clients_finished_end_run and not client.simulate_running: + client.simulate_running = True + self.clients_finished_end_run.append(client.client_name) + return client + return None + def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name=RunnerTask.TASK_EXEC): open_port = get_open_ports(1)[0] client_workspace = os.path.join(self.args.workspace, client.client_name)