# Adapted from https://psutil.readthedocs.io/en/latest/#kill-process-tree
def terminate_process_tree(pid, timeout=30):
    """Terminate the process ``pid`` and every one of its descendants.

    Each process in the tree is first asked to stop with ``terminate()``.
    Any process still alive once ``timeout`` has elapsed is then stopped
    with ``kill()``.

    Args:
        pid: PID of the root of the process tree to shut down.
        timeout: Seconds to wait for the tree to exit after ``terminate()``
            before falling back to ``kill()``. Defaults to 30.

    Raises:
        psutil.NoSuchProcess: If no process with ``pid`` exists when the
            tree is looked up. Processes that vanish later, while being
            signalled, are silently skipped.
    """
    process = psutil.Process(pid)
    # Collect the descendants up front so the whole tree is signalled;
    # the root itself is appended last.
    targets = process.children(recursive=True)
    targets.append(process)

    for proc in targets:
        try:
            proc.terminate()
        except psutil.NoSuchProcess:
            # Already exited on its own; nothing left to stop.
            pass

    _, survivors = psutil.wait_procs(targets, timeout=timeout)
    for proc in survivors:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            # Exited between the wait timing out and the kill. Keep going
            # so one race does not leave the other survivors running.
            pass