# Reconstructed from a unified diff whose line breaks were lost; only the
# post-change definitions are shown.

# --- python/sglang/srt/utils.py ---


def retry(
    fn,
    max_retry: int,
    initial_delay: float = 2.0,
    max_delay: float = 60.0,
    should_retry: Callable[[Any], bool] = lambda e: True,
):
    """Call ``fn`` until it succeeds, backing off exponentially between tries.

    Args:
        fn: Zero-argument callable to invoke.
        max_retry: Number of retries allowed after the first attempt, so
            ``fn`` is called at most ``max_retry + 1`` times.
        initial_delay: Base delay in seconds before the first retry; it is
            doubled on every subsequent retry.
        max_delay: Upper bound in seconds on the delay before jitter.
        should_retry: Predicate receiving the caught exception; returning
            ``False`` aborts immediately instead of retrying.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        Exception: If the retry budget is exhausted or ``should_retry``
            rejects the error. The exception raised by ``fn`` is attached
            as ``__cause__``.
    """
    for try_index in itertools.count():
        try:
            return fn()
        except Exception as e:
            # Chain with ``from e`` so callers can still inspect the real
            # failure behind the generic wrapper exception.
            if try_index >= max_retry:
                raise Exception("retry() exceed maximum number of retries.") from e

            if not should_retry(e):
                raise Exception(
                    "retry() observe errors that should not be retried."
                ) from e

            # Exponential backoff capped at max_delay, then scaled by a random
            # factor in [0.75, 1.0) so the waits are not all identical.
            delay = min(initial_delay * (2**try_index), max_delay) * (
                0.75 + 0.25 * random.random()
            )

            logger.warning(
                f"retry() failed once ({try_index}th try, maximum {max_retry} retries). "
                f"Will delay {delay:.2f}s and retry. Error: {e}"
            )
            traceback.print_exc()

            time.sleep(delay)


# --- python/sglang/test/test_utils.py ---


class CustomTestCase(unittest.TestCase):
    """``unittest.TestCase`` subclass that runs each test method via ``retry``."""

    def _callTestMethod(self, method):
        # The retry budget is read from SGLANG_TEST_MAX_RETRY; when the
        # variable is unset it is 2 if is_in_ci() is truthy and 0 otherwise.
        max_retry = int(
            os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
        )
        retry(
            lambda: super(CustomTestCase, self)._callTestMethod(method),
            max_retry=max_retry,
        )