Your current environment
The output of `python collect_env.py`
Collecting environment information...
==============================
System Info
==============================
OS : Ubuntu 22.04.5 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : Could not collect
CMake version : version 3.29.2
Libc version : glibc-2.35
==============================
PyTorch Info
==============================
PyTorch version : 2.7.0+cu126
Is debug build : False
CUDA used to build PyTorch : 12.6
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] (64-bit runtime)
Python platform : Linux-6.8.0-1030-gcp-x86_64-with-glibc2.35
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : Could not collect
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration :
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
Nvidia driver version : 570.133.07
cuDNN version : Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 52
On-line CPU(s) list: 0-51
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 26
Socket(s): 1
Stepping: 8
BogoMIPS: 5399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.2 MiB (26 instances)
L1i cache: 832 KiB (26 instances)
L2 cache: 52 MiB (26 instances)
L3 cache: 105 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-51
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-cufile-cu12==1.11.1.6
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pyzmq==27.0.0
[pip3] torch==2.7.0
[pip3] torchaudio==2.7.0
[pip3] torchvision==0.22.0
[pip3] transformers==4.53.2
[pip3] triton==3.3.0
[conda] Could not collect
==============================
vLLM Info
==============================
ROCM Version : Could not collect
Neuron SDK Version : N/A
vLLM Version : 0.9.2
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 0-51 0 N/A
GPU1 NV18 X 0-51 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
==============================
Environment Variables
==============================
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
Models quantized with GPTQModel (https://github.com/ModelCloud/GPTQModel/), whether with the latest version published on PyPI (2.2.0) or built from source (commit 5d2911a4b2a709afb0941d53c3882d0cd80b9649, 4.0.0-dev), cannot be loaded in vllm 0.9.2 (but load fine in vllm 0.9.1).
I tried Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-32B-Instruct; I assume the issue replicates on the 7B and the other sizes.
Command to start vllm (both 0.9.1 and 0.9.2):
python3 -m vllm.entrypoints.openai.api_server --model 3B --generation-config vllm --max-model-len 32768 -tp 1 --limit_mm_per_prompt '{"images": 6, "videos": 0}'

Output of vllm 0.9.2
INFO 07-15 12:50:00 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:50:03 [api_server.py:1395] vLLM API server version 0.9.2
INFO 07-15 12:50:03 [cli_args.py:325] non-default args: {'model': '../GPTQModel/3B', 'max_model_len': 32768, 'generation_config': 'vllm', 'limit_mm_per_prompt': {'images': 6, 'videos': 0}}
INFO 07-15 12:50:07 [config.py:841] This model supports multiple tasks: {'reward', 'embed', 'classify', 'generate'}. Defaulting to 'generate'.
INFO 07-15 12:50:07 [config.py:1472] Using max model len 32768
INFO 07-15 12:50:07 [gptq_marlin.py:170] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
INFO 07-15 12:50:08 [config.py:2285] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 07-15 12:50:11 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:50:13 [core.py:526] Waiting for init message from front-end.
INFO 07-15 12:50:13 [core.py:69] Initializing a V1 LLM engine (v0.9.2) with config: model='../GPTQModel/3B', speculative_config=None, tokenizer='../GPTQModel/3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=../GPTQModel/3B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
INFO 07-15 12:50:19 [parallel_state.py:1076] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 07-15 12:50:19 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 07-15 12:50:19 [gpu_model_runner.py:1770] Starting to load model ../GPTQModel/3B...
INFO 07-15 12:50:19 [gpu_model_runner.py:1775] Loading model from scratch...
WARNING 07-15 12:50:19 [vision.py:91] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
INFO 07-15 12:50:20 [gptq_marlin.py:266] Using MacheteLinearKernel for GPTQMarlinLinearMethod
INFO 07-15 12:50:20 [cuda.py:284] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00
INFO 07-15 12:50:20 [default_loader.py:272] Loading weights took 0.69 seconds
INFO 07-15 12:50:21 [gpu_model_runner.py:1801] Model loading took 3.3225 GiB and 1.205221 seconds
INFO 07-15 12:50:21 [gpu_model_runner.py:2238] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 1 image items of the maximum feature size.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
ERROR 07-15 12:50:26 [core.py:586] EngineCore failed to start.
ERROR 07-15 12:50:26 [core.py:586] Traceback (most recent call last):
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
ERROR 07-15 12:50:26 [core.py:586] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 404, in __init__
ERROR 07-15 12:50:26 [core.py:586] super().__init__(vllm_config, executor_class, log_stats,
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 82, in __init__
ERROR 07-15 12:50:26 [core.py:586] self._initialize_kv_caches(vllm_config)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
ERROR 07-15 12:50:26 [core.py:586] available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
ERROR 07-15 12:50:26 [core.py:586] output = self.collective_rpc("determine_available_memory")
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 07-15 12:50:26 [core.py:586] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/utils/__init__.py", line 2736, in run_method
ERROR 07-15 12:50:26 [core.py:586] return func(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 07-15 12:50:26 [core.py:586] return func(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
ERROR 07-15 12:50:26 [core.py:586] self.model_runner.profile_run()
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2274, in profile_run
ERROR 07-15 12:50:26 [core.py:586] = self._dummy_run(self.max_num_tokens, is_profile=True)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 07-15 12:50:26 [core.py:586] return func(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2057, in _dummy_run
ERROR 07-15 12:50:26 [core.py:586] outputs = model(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 07-15 12:50:26 [core.py:586] return self._call_impl(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 07-15 12:50:26 [core.py:586] return forward_call(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1145, in forward
ERROR 07-15 12:50:26 [core.py:586] hidden_states = self.language_model.model(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 239, in __call__
ERROR 07-15 12:50:26 [core.py:586] output = self.compiled_callable(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
ERROR 07-15 12:50:26 [core.py:586] raise e.with_traceback(None) from None
ERROR 07-15 12:50:26 [core.py:586] torch._dynamo.exc.Unsupported: All torch_function overrides for call TorchInGraphFunctionVariable(_C.machete_mm, nonstrict_traceable=False) with args [TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), TensorVariable(), UserDefinedObjectVariable(PackedColumnParameter), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker()] and kwargs {} returned NotImplemented
ERROR 07-15 12:50:26 [core.py:586]
ERROR 07-15 12:50:26 [core.py:586] from user code:
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 355, in forward
ERROR 07-15 12:50:26 [core.py:586] hidden_states, residual = layer(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 254, in forward
ERROR 07-15 12:50:26 [core.py:586] hidden_states = self.self_attn(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 181, in forward
ERROR 07-15 12:50:26 [core.py:586] qkv, _ = self.qkv_proj(hidden_states)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 510, in forward
ERROR 07-15 12:50:26 [core.py:586] output_parallel = self.quant_method.apply(self, input_, bias)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/gptq_marlin.py", line 372, in apply
ERROR 07-15 12:50:26 [core.py:586] return self.kernel.apply_weights(layer, x, bias)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py", line 122, in apply_weights
ERROR 07-15 12:50:26 [core.py:586] output = ops.machete_mm(a=x_2d,
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/_custom_ops.py", line 1098, in machete_mm
ERROR 07-15 12:50:26 [core.py:586] return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
ERROR 07-15 12:50:26 [core.py:586]
ERROR 07-15 12:50:26 [core.py:586] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
ERROR 07-15 12:50:26 [core.py:586]
Process EngineCore_0:
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 590, in run_engine_core
raise e
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
engine_core = EngineCoreProc(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 404, in __init__
super().__init__(vllm_config, executor_class, log_stats,
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 82, in __init__
self._initialize_kv_caches(vllm_config)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
available_gpu_memory = self.model_executor.determine_available_memory()
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
output = self.collective_rpc("determine_available_memory")
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/utils/__init__.py", line 2736, in run_method
return func(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
self.model_runner.profile_run()
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2274, in profile_run
= self._dummy_run(self.max_num_tokens, is_profile=True)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2057, in _dummy_run
outputs = model(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1145, in forward
hidden_states = self.language_model.model(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 239, in __call__
output = self.compiled_callable(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: All torch_function overrides for call TorchInGraphFunctionVariable(_C.machete_mm, nonstrict_traceable=False) with args [TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), TensorVariable(), UserDefinedObjectVariable(PackedColumnParameter), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker()] and kwargs {} returned NotImplemented
from user code:
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 355, in forward
hidden_states, residual = layer(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 254, in forward
hidden_states = self.self_attn(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 181, in forward
qkv, _ = self.qkv_proj(hidden_states)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 510, in forward
output_parallel = self.quant_method.apply(self, input_, bias)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/gptq_marlin.py", line 372, in apply
return self.kernel.apply_weights(layer, x, bias)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py", line 122, in apply_weights
output = ops.machete_mm(a=x_2d,
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/_custom_ops.py", line 1098, in machete_mm
return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
[rank0]:[W715 12:50:26.177353314 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 1495, in <module>
uvloop.run(run_server(args))
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/uvloop/__init__.py", line 82, in run
return loop.run_until_complete(wrapper())
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/uvloop/__init__.py", line 61, in wrapper
return await main
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 1431, in run_server
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 1451, in run_server_worker
async with build_async_engine_client(args, client_config) as engine_client:
File "/usr/lib/python3.10/contextlib.py", line 199, in __aenter__
return await anext(self.gen)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 158, in build_async_engine_client
async with build_async_engine_client_from_engine_args(
File "/usr/lib/python3.10/contextlib.py", line 199, in __aenter__
return await anext(self.gen)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 194, in build_async_engine_client_from_engine_args
async_llm = AsyncLLM.from_vllm_config(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/async_llm.py", line 162, in from_vllm_config
return cls(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/async_llm.py", line 124, in __init__
self.engine_core = EngineCoreClient.make_async_mp_client(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 96, in make_async_mp_client
return AsyncMPClient(*client_args)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 666, in __init__
super().__init__(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 403, in __init__
with launch_core_engines(vllm_config, executor_class,
File "/usr/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/utils.py", line 434, in launch_core_engines
wait_for_engine_startup(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/utils.py", line 484, in wait_for_engine_startup
raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
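A possible stopgap, based only on reading the trace (I have not verified it): the failure happens while Dynamo traces torch.ops._C.machete_mm during the profile run, so disabling torch.compile with --enforce-eager, or excluding the Machete kernel via vLLM's VLLM_DISABLED_KERNELS environment variable (which the mixed-precision kernel selector checks by class name, falling back to Marlin), might sidestep it. A sketch applying both to the launch command above:

VLLM_DISABLED_KERNELS=MacheteLinearKernel python3 -m vllm.entrypoints.openai.api_server --model 3B --generation-config vllm --max-model-len 32768 -tp 1 --limit_mm_per_prompt '{"images": 6, "videos": 0}' --enforce-eager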
Output of vllm 0.9.1 (no bug)
INFO 07-15 12:47:01 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:47:04 [api_server.py:1287] vLLM API server version 0.9.1
INFO 07-15 12:47:05 [cli_args.py:309] non-default args: {'model': '../GPTQModel/3B', 'max_model_len': 32768, 'generation_config': 'vllm', 'limit_mm_per_prompt': {'images': 6, 'videos': 0}}
INFO 07-15 12:47:12 [config.py:823] This model supports multiple tasks: {'score', 'classify', 'embed', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 07-15 12:47:13 [gptq_marlin.py:145] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
INFO 07-15 12:47:13 [config.py:2195] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 07-15 12:47:15 [env_override.py:17] NCCL_CUMEM_ENABLE is set to 0, skipping override. This may increase memory overhead with cudagraph+allreduce: https://github.com/NVIDIA/nccl/issues/1234
INFO 07-15 12:47:17 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:47:20 [core.py:455] Waiting for init message from front-end.
INFO 07-15 12:47:20 [core.py:70] Initializing a V1 LLM engine (v0.9.1) with config: model='../GPTQModel/3B', speculative_config=None, tokenizer='../GPTQModel/3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=../GPTQModel/3B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
WARNING 07-15 12:47:20 [utils.py:2737] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in
INFO 07-15 12:47:26 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
WARNING 07-15 12:47:31 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 07-15 12:47:31 [gpu_model_runner.py:1595] Starting to load model ../GPTQModel/3B...
INFO 07-15 12:47:31 [gpu_model_runner.py:1600] Loading model from scratch...
WARNING 07-15 12:47:31 [vision.py:91] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
INFO 07-15 12:47:31 [gptq_marlin.py:240] Using MacheteLinearKernel for GPTQMarlinLinearMethod
INFO 07-15 12:47:32 [cuda.py:252] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00
INFO 07-15 12:47:32 [default_loader.py:272] Loading weights took 0.75 seconds
INFO 07-15 12:47:33 [gpu_model_runner.py:1624] Model loading took 3.3225 GiB and 1.250191 seconds
INFO 07-15 12:47:35 [gpu_model_runner.py:1978] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 1 image items of the maximum feature size.
INFO 07-15 12:47:46 [backends.py:462] Using cache directory: /home/soren/.cache/vllm/torch_compile_cache/03f03888ae/rank_0_0 for vLLM's torch.compile
INFO 07-15 12:47:46 [backends.py:472] Dynamo bytecode transform time: 8.82 s
INFO 07-15 12:47:49 [backends.py:161] Cache the graph of shape None for later use
INFO 07-15 12:48:18 [backends.py:173] Compiling a graph for general shape takes 31.69 s
INFO 07-15 12:48:41 [monitor.py:34] torch.compile takes 40.52 s in total
INFO 07-15 12:48:42 [gpu_worker.py:227] Available KV cache memory: 61.88 GiB
INFO 07-15 12:48:43 [kv_cache_utils.py:715] GPU KV cache size: 1,802,496 tokens
INFO 07-15 12:48:43 [kv_cache_utils.py:719] Maximum concurrency for 32,768 tokens per request: 55.01x
INFO 07-15 12:49:14 [gpu_model_runner.py:2048] Graph capturing finished in 31 secs, took 0.80 GiB
INFO 07-15 12:49:14 [core.py:171] init engine (profile, create kv cache, warmup model) took 101.19 seconds
INFO 07-15 12:49:16 [loggers.py:137] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 112656
INFO 07-15 12:49:17 [api_server.py:1349] Starting vLLM API server 0 on http://0.0.0.0:8000
INFO 07-15 12:49:17 [launcher.py:29] Available routes are:
INFO 07-15 12:49:17 [launcher.py:37] Route: /openapi.json, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /docs, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /redoc, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /health, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /load, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /ping, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /ping, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /tokenize, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /detokenize, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/models, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /version, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/chat/completions, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/completions, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/embeddings, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /pooling, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /classify, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /score, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/score, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/audio/transcriptions, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /rerank, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/rerank, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v2/rerank, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /invocations, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /metrics, Methods: GET
INFO: Started server process [67381]
INFO: Waiting for application startup.
INFO: Application startup complete.
For convenience, I have uploaded the quantized model here:
https://huggingface.co/NM-dev/Qwen2.5-VL-3B-Instruct-GPTQModel-W4A16-G128
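With that checkpoint, a minimal offline reproducer (a sketch using the plain vLLM Python API, configured like the server command above):

from vllm import LLM

# On vLLM 0.9.2 this should fail during the profile run with the
# machete_mm Dynamo error shown above; on 0.9.1 it initializes normally.
llm = LLM(model="NM-dev/Qwen2.5-VL-3B-Instruct-GPTQModel-W4A16-G128",
          max_model_len=32768)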
If you want to reproduce the model and quantize it yourself:
#!/usr/bin/env python
# coding: utf-8
import os
import random
import json
import io
import base64
import types

from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.models.base import BaseGPTQModel
from transformers import AutoProcessor
from PIL import Image

SEED = 42
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
quant_path = "3B"
random.seed(SEED)

quant_config = QuantizeConfig(bits=4, group_size=128)
model = GPTQModel.load(model_path, quant_config)
processor = AutoProcessor.from_pretrained(model_path)

# 256 ChartQA samples as multimodal calibration data
calibration_dataset = load_dataset(
    "HuggingFaceM4/ChartQA",
    split="train",
).select(range(256))

calib_data = []
for sample in calibration_dataset:
    try:
        messages = []
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "image", "image_url": sample["image"]},
                    {"type": "text", "text": sample["query"]},
                ],
            }
        )
        messages.append(
            {
                "role": "assistant",
                "content": sample["label"][0],
            }
        )
        text = processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        calib_data.append(inputs)
    except Exception as e:
        print(e)

# route prepare_dataset to the base implementation (bypassing the
# Qwen2.5-VL-specific one) so the pre-tokenized inputs are used as-is
model.prepare_dataset = lambda *args, **kwargs: BaseGPTQModel.prepare_dataset(model, *args, **kwargs)
model.quantize(calib_data, batch_size=1)
model.save(quant_path)
processor.save_pretrained(quant_path)
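To sanity-check the freshly quantized checkpoint outside vLLM, a short generation sketch (this assumes GPTQModel's README-style convenience API, i.e. a string-input generate() and a model.tokenizer attribute):

from gptqmodel import GPTQModel

# Reload the saved checkpoint for inference and decode a few tokens;
# generate() on a string and model.tokenizer are assumed from GPTQModel's examples.
model = GPTQModel.load("3B")
tokens = model.generate("Describe the chart in one sentence:")[0]
print(model.tokenizer.decode(tokens))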
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.