From 73fc7c7216098b41d5515a4d0856419064f50a2b Mon Sep 17 00:00:00 2001
From: arlo <264998716+arlo-aisys@users.noreply.github.com>
Date: Thu, 5 Mar 2026 20:11:29 +0800
Subject: [PATCH] [Feature] Add support for InstantTensor
Signed-off-by: arlo <264998716+arlo-aisys@users.noreply.github.com>
---
docker/Dockerfile.cpu | 4 +-
docs/models/extensions/instanttensor.md | 31 +++++++++++
requirements/nightly_torch_test.txt | 1 +
requirements/test.in | 1 +
requirements/test.txt | 3 ++
setup.py | 1 +
.../instanttensor_loader/__init__.py | 0
.../test_instanttensor_loader.py | 28 ++++++++++
.../instanttensor_loader/test_weight_utils.py | 52 +++++++++++++++++++
vllm/config/load.py | 5 +-
vllm/model_executor/model_loader/__init__.py | 2 +
.../model_loader/default_loader.py | 12 ++++-
.../model_loader/weight_utils.py | 42 ++++++++++++++-
13 files changed, 177 insertions(+), 5 deletions(-)
create mode 100644 docs/models/extensions/instanttensor.md
create mode 100644 tests/model_executor/model_loader/instanttensor_loader/__init__.py
create mode 100644 tests/model_executor/model_loader/instanttensor_loader/test_instanttensor_loader.py
create mode 100644 tests/model_executor/model_loader/instanttensor_loader/test_weight_utils.py
diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu
index d81957e02d19..6f020e30ed47 100644
--- a/docker/Dockerfile.cpu
+++ b/docker/Dockerfile.cpu
@@ -36,7 +36,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -y \
&& apt-get install -y --no-install-recommends sudo ccache git curl wget ca-certificates \
- gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \
+ gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof make \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
@@ -174,7 +174,7 @@ WORKDIR /vllm-workspace
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
- apt-get install -y --no-install-recommends vim numactl xz-utils make clangd-14
+ apt-get install -y --no-install-recommends vim numactl xz-utils clangd-14
RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
diff --git a/docs/models/extensions/instanttensor.md b/docs/models/extensions/instanttensor.md
new file mode 100644
index 000000000000..0ac7094cefb9
--- /dev/null
+++ b/docs/models/extensions/instanttensor.md
@@ -0,0 +1,31 @@
+# Loading Model Weights with InstantTensor
+
+InstantTensor accelerates loading Safetensors weights on CUDA devices through distributed loading, pipelined prefetching, and direct I/O. InstantTensor also supports GDS (GPUDirect Storage) when available.
+For more details, see the [InstantTensor GitHub repository](https://github.com/scitix/InstantTensor).
+
+## Installation
+
+```bash
+pip install instanttensor
+```
+
+## Use InstantTensor in vLLM
+
+Add `--load-format instanttensor` as a command-line argument.
+
+For example:
+
+```bash
+vllm serve Qwen/Qwen2.5-0.5B --load-format instanttensor
+```
+
+## Benchmarks
+
+| Model | GPU | Backend | Load Time (s) | Throughput (GB/s) | Speedup |
+| --- | ---: | --- | ---: | ---: | --- |
+| Qwen3-30B-A3B | 1*H200 | Safetensors | 57.4 | 1.1 | 1x |
+| Qwen3-30B-A3B | 1*H200 | InstantTensor | 1.77 | 35 | **32.4x** |
+| DeepSeek-R1 | 8*H200 | Safetensors | 160 | 4.3 | 1x |
+| DeepSeek-R1 | 8*H200 | InstantTensor | 15.3 | 45 | **10.5x** |
+
+For the full benchmark results, see the [InstantTensor GitHub repository](https://github.com/scitix/InstantTensor).
diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt
index 27299f47ff4e..4d2bf8d2b100 100644
--- a/requirements/nightly_torch_test.txt
+++ b/requirements/nightly_torch_test.txt
@@ -44,4 +44,5 @@ numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.15.3
fastsafetensors>=0.2.2
+instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13
diff --git a/requirements/test.in b/requirements/test.in
index 5e6e3256a725..295028ca8618 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -57,6 +57,7 @@ numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.15.3
fastsafetensors>=0.2.2 # 0.2.2 contains important fixes for multi-GPU mem usage
+instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0
terratorch >= 1.2.2 # Required for Prithvi tests
diff --git a/requirements/test.txt b/requirements/test.txt
index ac5fb9c2edff..2635fda4cdf3 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -375,6 +375,8 @@ inflect==5.6.2
# via datamodel-code-generator
iniconfig==2.0.0
# via pytest
+instanttensor==0.1.5
+ # via -r requirements/test.in
isoduration==20.11.0
# via jsonschema
isort==5.13.2
@@ -1169,6 +1171,7 @@ torch==2.10.0+cu129
# accelerate
# bitsandbytes
# encodec
+ # instanttensor
# kornia
# lightly
# lightning
diff --git a/setup.py b/setup.py
index fa13fff4e62e..9d5224beb531 100644
--- a/setup.py
+++ b/setup.py
@@ -968,6 +968,7 @@ def _read_requirements(filename: str) -> list[str]:
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
+ "instanttensor": ["instanttensor >= 0.1.5"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"],
"audio": [
"librosa",
diff --git a/tests/model_executor/model_loader/instanttensor_loader/__init__.py b/tests/model_executor/model_loader/instanttensor_loader/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/model_executor/model_loader/instanttensor_loader/test_instanttensor_loader.py b/tests/model_executor/model_loader/instanttensor_loader/test_instanttensor_loader.py
new file mode 100644
index 000000000000..e9042305be23
--- /dev/null
+++ b/tests/model_executor/model_loader/instanttensor_loader/test_instanttensor_loader.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.platforms import current_platform
+
+test_model = "openai-community/gpt2"
+
+prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
+
+
+@pytest.mark.skipif(
+ not current_platform.is_cuda(),
+ reason="InstantTensor requires NVIDIA GPUs",
+)
+def test_model_loader_download_files(vllm_runner):
+ with vllm_runner(test_model, load_format="instanttensor") as llm:
+ deserialized_outputs = llm.generate(prompts, sampling_params)
+ assert deserialized_outputs
diff --git a/tests/model_executor/model_loader/instanttensor_loader/test_weight_utils.py b/tests/model_executor/model_loader/instanttensor_loader/test_weight_utils.py
new file mode 100644
index 000000000000..992a83e0eea4
--- /dev/null
+++ b/tests/model_executor/model_loader/instanttensor_loader/test_weight_utils.py
@@ -0,0 +1,52 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import glob
+import tempfile
+
+import huggingface_hub.constants
+import pytest
+import torch
+
+from vllm.model_executor.model_loader.weight_utils import (
+ download_weights_from_hf,
+ instanttensor_weights_iterator,
+ safetensors_weights_iterator,
+)
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(
+ not current_platform.is_cuda(),
+ reason="InstantTensor requires NVIDIA GPUs",
+)
+def test_instanttensor_model_loader():
+ with tempfile.TemporaryDirectory() as tmpdir:
+ huggingface_hub.constants.HF_HUB_OFFLINE = False
+ download_weights_from_hf(
+ "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir
+ )
+ safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
+ assert len(safetensors) > 0
+
+ instanttensor_tensors = {}
+ hf_safetensors_tensors = {}
+
+ for name, tensor in instanttensor_weights_iterator(safetensors, True):
+ # Copy the tensor immediately as it is a reference to the internal
+ # buffer of instanttensor.
+ instanttensor_tensors[name] = tensor.to("cpu")
+
+ for name, tensor in safetensors_weights_iterator(safetensors, True):
+ hf_safetensors_tensors[name] = tensor
+
+ assert len(instanttensor_tensors) == len(hf_safetensors_tensors)
+
+ for name, instanttensor_tensor in instanttensor_tensors.items():
+ assert instanttensor_tensor.dtype == hf_safetensors_tensors[name].dtype
+ assert instanttensor_tensor.shape == hf_safetensors_tensors[name].shape
+ assert torch.all(instanttensor_tensor.eq(hf_safetensors_tensors[name]))
+
+
+if __name__ == "__main__":
+ test_instanttensor_model_loader()
diff --git a/vllm/config/load.py b/vllm/config/load.py
index 64a269e9885a..b771556d83f2 100644
--- a/vllm/config/load.py
+++ b/vllm/config/load.py
@@ -29,6 +29,9 @@ class LoadConfig:
back to the pytorch bin format if safetensors format is not available.\n
- "pt" will load the weights in the pytorch bin format.\n
- "safetensors" will load the weights in the safetensors format.\n
+ - "instanttensor" will load the Safetensors weights on CUDA devices using
+ InstantTensor, which enables distributed loading with pipelined prefetching
+ and fast direct I/O.\n
- "npcache" will load the weights in pytorch format and store a numpy cache
to speed up the loading.\n
- "dummy" will initialize the weights with random values, which is mainly
@@ -46,7 +49,7 @@ class LoadConfig:
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by
- Mistral models.
+ Mistral models.\n
- Other custom values can be supported via plugins."""
download_dir: str | None = None
"""Directory to download and load the weights, default to the default
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index ff95d5b945c6..53b6b3221b54 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -35,6 +35,7 @@
"dummy",
"fastsafetensors",
"gguf",
+ "instanttensor",
"mistral",
"npcache",
"pt",
@@ -51,6 +52,7 @@
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
"gguf": GGUFModelLoader,
+ "instanttensor": DefaultModelLoader,
"mistral": DefaultModelLoader,
"npcache": DefaultModelLoader,
"pt": DefaultModelLoader,
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 7064998af86b..2d5abcb1c157 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -23,6 +23,7 @@
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
get_quant_config,
+ instanttensor_weights_iterator,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
@@ -117,7 +118,11 @@ def _prepare_weights(
# Some quantized models use .pt files for storing the weights.
if load_format == "hf":
allow_patterns = ["*.safetensors", "*.bin"]
- elif load_format == "safetensors" or load_format == "fastsafetensors":
+ elif (
+ load_format == "safetensors"
+ or load_format == "fastsafetensors"
+ or load_format == "instanttensor"
+ ):
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "mistral":
@@ -209,6 +214,11 @@ def _get_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
+ elif self.load_config.load_format == "instanttensor":
+ weights_iterator = instanttensor_weights_iterator(
+ hf_weights_files,
+ self.load_config.use_tqdm_on_load,
+ )
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_safetensors_weights_iterator(
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index e00a17a153fb..c114c84aa456 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -29,7 +29,7 @@
from vllm import envs
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
-from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed import get_tensor_model_parallel_rank, get_world_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
@@ -897,6 +897,46 @@ def fastsafetensors_weights_iterator(
loader.close()
+def instanttensor_weights_iterator(
+ hf_weights_files: list[str],
+ use_tqdm_on_load: bool,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+ """Iterate over the weights in the model safetensor files
+ using instanttensor library."""
+ try:
+ import instanttensor
+ except ImportError as e:
+ raise ImportError(
+ "Please install instanttensor via `pip install instanttensor`"
+ ) from e
+
+ if not current_platform.is_cuda():
+ raise ValueError("InstantTensor requires NVIDIA GPUs")
+
+ try:
+ world_group = get_world_group()
+ except AssertionError:
+ # Entering here only in unit tests where the world group is not initialized.
+ process_group = None
+ else:
+ process_group = world_group.device_group if world_group.world_size > 1 else None
+
+ device = current_platform.current_device()
+
+ with instanttensor.safe_open(
+ hf_weights_files, framework="pt", device=device, process_group=process_group
+ ) as f:
+ yield from tqdm(
+ f.tensors(),
+ desc="Loading safetensors using InstantTensor loader",
+ disable=not enable_tqdm(use_tqdm_on_load),
+ bar_format=_BAR_FORMAT,
+ position=tqdm._get_free_pos(),
+ total=len(f.keys()),
+ mininterval=1.0,
+ )
+
+
def pt_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,