Skip to content
30 changes: 30 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import tempfile
import threading
import time
import traceback
import warnings
from contextlib import contextmanager
from functools import lru_cache
Expand Down Expand Up @@ -1766,3 +1767,32 @@ def parse_connector_type(url: str) -> str:
return ""

return m.group(1)


def retry(
    fn,
    max_retry: int,
    initial_delay: float = 2.0,
    max_delay: float = 60.0,
    should_retry: Callable[[Any], bool] = lambda e: True,
):
    """Call ``fn`` until it succeeds, retrying with jittered exponential backoff.

    Args:
        fn: Zero-argument callable to invoke.
        max_retry: Number of retries allowed after the first attempt, so
            ``fn`` is called at most ``max_retry + 1`` times.
        initial_delay: Base delay in seconds before the first retry; it is
            doubled for every further retry.
        max_delay: Upper bound in seconds applied to the delay before jitter.
        should_retry: Predicate that receives the caught exception; returning
            a falsy value stops retrying immediately.

    Returns:
        The value returned by the first successful call to ``fn``.

    Raises:
        Exception: If the retry budget is exhausted, or ``should_retry``
            rejects the error. The error raised by ``fn`` is attached as
            ``__cause__``. Only ``Exception`` subclasses raised by ``fn`` are
            intercepted; other ``BaseException``s propagate unchanged.
    """
    for try_index in itertools.count():
        try:
            return fn()
        except Exception as e:
            # The budget check runs before should_retry, so an exhausted
            # budget is reported even for errors that are not retryable.
            if try_index >= max_retry:
                # Chain the original error so callers can inspect the cause.
                raise Exception("retry() exceed maximum number of retries.") from e

            if not should_retry(e):
                raise Exception(
                    "retry() observe errors that should not be retried."
                ) from e

            # Exponential backoff capped at max_delay, then scaled by a random
            # factor in [0.75, 1.0) so concurrent callers do not retry in sync.
            delay = min(initial_delay * (2**try_index), max_delay) * (
                0.75 + 0.25 * random.random()
            )

            logger.warning(
                f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
            )
            traceback.print_exc()

            time.sleep(delay)
28 changes: 6 additions & 22 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import get_bool_env_var, kill_process_tree
from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry
from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback

Expand Down Expand Up @@ -1010,26 +1010,10 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):

class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
_retry_execution(
lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=_get_max_retry(),
max_retry = int(
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
)


def _get_max_retry():
return int(os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0"))


def _retry_execution(fn, max_retry: int):
if max_retry == 0:
fn()
return

try:
fn()
except Exception as e:
print(
f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
retry(
lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=max_retry,
)
traceback.print_exc()
_retry_execution(fn, max_retry=max_retry - 1)
Loading