Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import logging
import multiprocessing
import os
Expand Down Expand Up @@ -56,7 +57,14 @@ def wrapper(self, *args, **kwargs):
class Hf3fsUsrBioClient(Hf3fsClient):
"""HF3FS client implementation using usrbio."""

def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
def __init__(
self,
path: str,
size: int,
bytes_per_page: int,
entries: int,
client_timeout: int,
):
if not HF3FS_AVAILABLE:
raise ImportError(
"hf3fs_fuse.io is not available. Please install the hf3fs_fuse package."
Expand All @@ -66,6 +74,7 @@ def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
self.size = size
self.bytes_per_page = bytes_per_page
self.entries = entries
self.client_timeout = client_timeout

self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
os.ftruncate(self.file, size)
Expand Down Expand Up @@ -121,7 +130,9 @@ def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[in

# submit
ionum = len(offsets)
resv = self.ior_r.submit().wait(min_results=ionum)
resv = self.ior_r.submit().wait(
min_results=ionum, timeout=datetime.timedelta(seconds=self.client_timeout)
)

# results
hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
Expand All @@ -145,7 +156,9 @@ def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[i

# submit
ionum = len(offsets)
resv = self.ior_w.submit().wait(min_results=ionum)
resv = self.ior_w.submit().wait(
min_results=ionum, timeout=datetime.timedelta(seconds=self.client_timeout)
)

# results
results = [res.result for res in resv]
Expand Down
17 changes: 12 additions & 5 deletions python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ def wrapper(self, *args, **kwargs):


def create_hf3fs_client(
path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
path: str,
size: int,
bytes_per_page: int,
entries: int,
client_timeout: int,
use_mock: bool = False,
) -> Hf3fsClient:
"""Factory function to create appropriate HF3FS client.

Expand All @@ -143,7 +148,7 @@ def create_hf3fs_client(
Hf3fsUsrBioClient,
)

return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
return Hf3fsUsrBioClient(path, size, bytes_per_page, entries, client_timeout)


class HiCacheHF3FS(HiCacheStorage):
Expand All @@ -159,6 +164,7 @@ def __init__(
numjobs: int,
bytes_per_page: int,
entries: int,
client_timeout: int,
dtype: torch.dtype,
metadata_client: Hf3fsMetadataInterface,
is_mla_model: bool = False,
Expand All @@ -172,6 +178,7 @@ def __init__(
self.bytes_per_page = bytes_per_page
self.gb_per_page = bytes_per_page / (1 << 30)
self.entries = entries
self.client_timeout = client_timeout
self.dtype = dtype
self.metadata_client = metadata_client
self.is_mla_model = is_mla_model
Expand Down Expand Up @@ -200,6 +207,7 @@ def __init__(
self.file_size,
self.bytes_per_page,
self.entries,
self.client_timeout,
use_mock_client,
)
for _ in range(numjobs)
Expand Down Expand Up @@ -275,6 +283,7 @@ def from_env_config(
numjobs=16,
bytes_per_page=bytes_per_page,
entries=8,
client_timeout=5,
dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
is_page_first_layout=is_page_first_layout,
Expand Down Expand Up @@ -324,14 +333,14 @@ def from_env_config(
numjobs=int(config["numjobs"]),
bytes_per_page=bytes_per_page,
entries=int(config["entries"]),
client_timeout=config.get("client_timeout", 5),
dtype=dtype,
metadata_client=metadata_client,
is_mla_model=is_mla_model,
is_page_first_layout=is_page_first_layout,
use_mock_client=use_mock_client,
)

@synchronized()
def _batch_get(
self,
keys: List[str],
Expand Down Expand Up @@ -379,7 +388,6 @@ def _batch_get(

return results

@synchronized()
def _batch_set(
self,
keys: List[str],
Expand Down Expand Up @@ -486,7 +494,6 @@ def close(self) -> None:
logger.error(f"close HiCacheHF3FS: {e}")
logger.info("close HiCacheHF3FS")

@synchronized()
def get_stats(self):
storage_metrics = StorageMetrics()
storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
Expand Down
Loading