def download_image_with_retry(image_url: str, max_retries: int = 3) -> "Image.Image":
    """Download an image over HTTP, retrying with exponential backoff.

    Args:
        image_url: URL of the image to fetch.
        max_retries: Total number of attempts; must be >= 1.

    Returns:
        A fully-loaded PIL image. ``image.load()`` is called so the pixel
        data is read eagerly, making the result independent of the
        temporary ``BytesIO`` response buffer.

    Raises:
        ValueError: If ``max_retries`` is less than 1 (previously this
            case silently returned ``None``).
        RuntimeError: If every attempt fails; chained from the last
            underlying exception.
    """
    if max_retries < 1:
        # Guard: with max_retries <= 0 the loop below would never run and
        # the function would implicitly return None, deferring the failure
        # to an opaque AttributeError at the call site.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    for attempt in range(max_retries):
        try:
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            # Force PIL to decode now; Image.open is lazy and would
            # otherwise read from the BytesIO after we return.
            image.load()
            return image
        except Exception as e:
            if attempt == max_retries - 1:
                raise RuntimeError(
                    f"Failed to download image after {max_retries} retries: {image_url}"
                ) from e
            # Exponential backoff between attempts: 1s, 2s, 4s, ...
            time.sleep(2**attempt)
cls.image_url = DEFAULT_IMAGE_URL - response = requests.get(cls.image_url) - cls.image = Image.open(BytesIO(response.content)) + cls.image = download_image_with_retry(cls.image_url) cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False) cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True) diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index 7f6c939f1d76..c722ed190fb0 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -2,14 +2,11 @@ """ import unittest -from io import BytesIO from typing import List, Optional import numpy as np -import requests import torch import torch.nn.functional as F -from PIL import Image from transformers import AutoModel, AutoProcessor, AutoTokenizer from sglang.srt.configs.model_config import ModelConfig @@ -24,6 +21,7 @@ from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor from sglang.srt.parser.conversation import generate_chat_conv from sglang.srt.server_args import ServerArgs +from sglang.test.test_utils import download_image_with_retry # Test the logits output between HF and SGLang @@ -35,8 +33,7 @@ def setUpClass(cls): cls.model_path = "" cls.chat_template = "" cls.processor = "" - response = requests.get(cls.image_url) - cls.main_image = Image.open(BytesIO(response.content)) + cls.main_image = download_image_with_retry(cls.image_url) def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor): # Convert to float32 for numerical stability if needed diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py index 18e4a6441d47..7105573dd6de 100644 --- a/test/srt/test_vlm_input_format.py +++ b/test/srt/test_vlm_input_format.py @@ -1,11 +1,8 @@ import json import unittest -from io import BytesIO from typing import Optional -import requests import torch -from PIL import Image from transformers import ( 
AutoProcessor, Gemma3ForConditionalGeneration, @@ -15,6 +12,7 @@ from sglang import Engine from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.parser.conversation import generate_chat_conv +from sglang.test.test_utils import download_image_with_retry TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true" @@ -31,8 +29,7 @@ def setUpClass(cls): assert cls.chat_template is not None, "Set chat_template in subclass" cls.image_url = TEST_IMAGE_URL cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - response = requests.get(cls.image_url) - cls.main_image = Image.open(BytesIO(response.content)) + cls.main_image = download_image_with_retry(cls.image_url) cls.processor = AutoProcessor.from_pretrained( cls.model_path, trust_remote_code=True, use_fast=True )