diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py index 480c18ed1c6f..253219826a8e 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_usrbio_client.py @@ -1,3 +1,4 @@ +import datetime import logging import multiprocessing import os @@ -56,7 +57,14 @@ def wrapper(self, *args, **kwargs): class Hf3fsUsrBioClient(Hf3fsClient): """HF3FS client implementation using usrbio.""" - def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): + def __init__( + self, + path: str, + size: int, + bytes_per_page: int, + entries: int, + client_timeout: int, + ): if not HF3FS_AVAILABLE: raise ImportError( "hf3fs_fuse.io is not available. Please install the hf3fs_fuse package." @@ -66,6 +74,7 @@ def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): self.size = size self.bytes_per_page = bytes_per_page self.entries = entries + self.client_timeout = client_timeout self.file = os.open(self.path, os.O_RDWR | os.O_CREAT) os.ftruncate(self.file, size) @@ -121,7 +130,9 @@ def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[in # submit ionum = len(offsets) - resv = self.ior_r.submit().wait(min_results=ionum) + resv = self.ior_r.submit().wait( + min_results=ionum, timeout=datetime.timedelta(seconds=self.client_timeout) + ) # results hf3fs_utils.read_shm(self.shm_r_tensor, tensors) @@ -145,7 +156,9 @@ def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[i # submit ionum = len(offsets) - resv = self.ior_w.submit().wait(min_results=ionum) + resv = self.ior_w.submit().wait( + min_results=ionum, timeout=datetime.timedelta(seconds=self.client_timeout) + ) # results results = [res.result for res in resv] diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py 
index a789c2af8214..55d34dc6291d 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -120,7 +120,12 @@ def wrapper(self, *args, **kwargs): def create_hf3fs_client( - path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False + path: str, + size: int, + bytes_per_page: int, + entries: int, + client_timeout: int, + use_mock: bool = False, ) -> Hf3fsClient: """Factory function to create appropriate HF3FS client. @@ -143,7 +148,7 @@ def create_hf3fs_client( Hf3fsUsrBioClient, ) - return Hf3fsUsrBioClient(path, size, bytes_per_page, entries) + return Hf3fsUsrBioClient(path, size, bytes_per_page, entries, client_timeout) class HiCacheHF3FS(HiCacheStorage): @@ -159,6 +164,7 @@ def __init__( numjobs: int, bytes_per_page: int, entries: int, + client_timeout: int, dtype: torch.dtype, metadata_client: Hf3fsMetadataInterface, is_mla_model: bool = False, @@ -172,6 +178,7 @@ def __init__( self.bytes_per_page = bytes_per_page self.gb_per_page = bytes_per_page / (1 << 30) self.entries = entries + self.client_timeout = client_timeout self.dtype = dtype self.metadata_client = metadata_client self.is_mla_model = is_mla_model @@ -200,6 +207,7 @@ def __init__( self.file_size, self.bytes_per_page, self.entries, + self.client_timeout, use_mock_client, ) for _ in range(numjobs) @@ -275,6 +283,7 @@ def from_env_config( numjobs=16, bytes_per_page=bytes_per_page, entries=8, + client_timeout=5, dtype=dtype, metadata_client=Hf3fsLocalMetadataClient(), is_page_first_layout=is_page_first_layout, @@ -324,6 +333,7 @@ def from_env_config( numjobs=int(config["numjobs"]), bytes_per_page=bytes_per_page, entries=int(config["entries"]), + client_timeout=int(config.get("client_timeout", 5)), dtype=dtype, metadata_client=metadata_client, is_mla_model=is_mla_model, @@ -331,7 +341,6 @@ def from_env_config( use_mock_client=use_mock_client, ) - @synchronized() def _batch_get( self, keys: List[str], @@ 
-379,7 +388,6 @@ def _batch_get( return results - @synchronized() def _batch_set( self, keys: List[str], @@ -486,7 +494,6 @@ def close(self) -> None: logger.error(f"close HiCacheHF3FS: {e}") logger.info("close HiCacheHF3FS") - @synchronized() def get_stats(self): storage_metrics = StorageMetrics() storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)