
[Bug]: 0.9.2: Qwen2.5-VL GPTQ MacheteLinearKernel for GPTQMarlinLinearMethod: torch._dynamo.exc.Unsupported #20986

@SorenDreano


Your current environment

The output of python collect_env.py
Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 22.04.5 LTS (x86_64)
GCC version                  : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version                : Could not collect
CMake version                : version 3.29.2
Libc version                 : glibc-2.35

==============================
       PyTorch Info
==============================
PyTorch version              : 2.7.0+cu126
Is debug build               : False
CUDA used to build PyTorch   : 12.6
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] (64-bit runtime)
Python platform              : Linux-6.8.0-1030-gcp-x86_64-with-glibc2.35

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : Could not collect
CUDA_MODULE_LOADING set to   : LAZY
GPU models and configuration : 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3

Nvidia driver version        : 570.133.07
cuDNN version                : Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        52 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               52
On-line CPU(s) list:                  0-51
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   26
Socket(s):                            1
Stepping:                             8
BogoMIPS:                             5399.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1.2 MiB (26 instances)
L1i cache:                            832 KiB (26 instances)
L2 cache:                             52 MiB (26 instances)
L3 cache:                             105 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-51
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

==============================
Versions of relevant libraries
==============================
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-cufile-cu12==1.11.1.6
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pyzmq==27.0.0
[pip3] torch==2.7.0
[pip3] torchaudio==2.7.0
[pip3] torchvision==0.22.0
[pip3] transformers==4.53.2
[pip3] triton==3.3.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
Neuron SDK Version           : N/A
vLLM Version                 : 0.9.2
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    0-51    0               N/A
GPU1    NV18     X      0-51    0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

==============================
     Environment Variables
==============================
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

Models quantized with GPTQModel (https://github.com/ModelCloud/GPTQModel/), using either the latest version published on PyPI (2.2.0) or a build from source at commit 5d2911a4b2a709afb0941d53c3882d0cd80b9649 (4.0.0-dev), cannot be loaded in vLLM 0.9.2, although they load fine in vLLM 0.9.1.
I tried Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-32B-Instruct. I assume the bug also reproduces with the 7B and other sizes.

Command used to start vLLM (same for 0.9.1 and 0.9.2):

python3 -m vllm.entrypoints.openai.api_server --model 3B --generation-config vllm --max-model-len 32768 -tp 1 --limit_mm_per_prompt '{"images": 6, "videos": 0}'
Output of vllm 0.9.2
INFO 07-15 12:50:00 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:50:03 [api_server.py:1395] vLLM API server version 0.9.2
INFO 07-15 12:50:03 [cli_args.py:325] non-default args: {'model': '../GPTQModel/3B', 'max_model_len': 32768, 'generation_config': 'vllm', 'limit_mm_per_prompt': {'images': 6, 'videos': 0}}
INFO 07-15 12:50:07 [config.py:841] This model supports multiple tasks: {'reward', 'embed', 'classify', 'generate'}. Defaulting to 'generate'.
INFO 07-15 12:50:07 [config.py:1472] Using max model len 32768
INFO 07-15 12:50:07 [gptq_marlin.py:170] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
INFO 07-15 12:50:08 [config.py:2285] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 07-15 12:50:11 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:50:13 [core.py:526] Waiting for init message from front-end.
INFO 07-15 12:50:13 [core.py:69] Initializing a V1 LLM engine (v0.9.2) with config: model='../GPTQModel/3B', speculative_config=None, tokenizer='../GPTQModel/3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=../GPTQModel/3B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
INFO 07-15 12:50:19 [parallel_state.py:1076] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 07-15 12:50:19 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 07-15 12:50:19 [gpu_model_runner.py:1770] Starting to load model ../GPTQModel/3B...
INFO 07-15 12:50:19 [gpu_model_runner.py:1775] Loading model from scratch...
WARNING 07-15 12:50:19 [vision.py:91] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
INFO 07-15 12:50:20 [gptq_marlin.py:266] Using MacheteLinearKernel for GPTQMarlinLinearMethod
INFO 07-15 12:50:20 [cuda.py:284] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00
INFO 07-15 12:50:20 [default_loader.py:272] Loading weights took 0.69 seconds
INFO 07-15 12:50:21 [gpu_model_runner.py:1801] Model loading took 3.3225 GiB and 1.205221 seconds
INFO 07-15 12:50:21 [gpu_model_runner.py:2238] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 1 image items of the maximum feature size.
Using a slow image processor as use_fast is unset and a slow processor was saved with this model. use_fast=True will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with use_fast=False.
ERROR 07-15 12:50:26 [core.py:586] EngineCore failed to start.
ERROR 07-15 12:50:26 [core.py:586] Traceback (most recent call last):
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
ERROR 07-15 12:50:26 [core.py:586] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 404, in __init__
ERROR 07-15 12:50:26 [core.py:586] super().__init__(vllm_config, executor_class, log_stats,
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 82, in __init__
ERROR 07-15 12:50:26 [core.py:586] self._initialize_kv_caches(vllm_config)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
ERROR 07-15 12:50:26 [core.py:586] available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
ERROR 07-15 12:50:26 [core.py:586] output = self.collective_rpc("determine_available_memory")
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 07-15 12:50:26 [core.py:586] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/utils/__init__.py", line 2736, in run_method
ERROR 07-15 12:50:26 [core.py:586] return func(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 07-15 12:50:26 [core.py:586] return func(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
ERROR 07-15 12:50:26 [core.py:586] self.model_runner.profile_run()
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2274, in profile_run
ERROR 07-15 12:50:26 [core.py:586] = self._dummy_run(self.max_num_tokens, is_profile=True)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 07-15 12:50:26 [core.py:586] return func(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2057, in _dummy_run
ERROR 07-15 12:50:26 [core.py:586] outputs = model(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 07-15 12:50:26 [core.py:586] return self._call_impl(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 07-15 12:50:26 [core.py:586] return forward_call(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1145, in forward
ERROR 07-15 12:50:26 [core.py:586] hidden_states = self.language_model.model(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 239, in __call__
ERROR 07-15 12:50:26 [core.py:586] output = self.compiled_callable(*args, **kwargs)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
ERROR 07-15 12:50:26 [core.py:586] raise e.with_traceback(None) from None
ERROR 07-15 12:50:26 [core.py:586] torch._dynamo.exc.Unsupported: All torch_function overrides for call TorchInGraphFunctionVariable(_C.machete_mm, nonstrict_traceable=False) with args [TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), TensorVariable(), UserDefinedObjectVariable(PackedColumnParameter), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker()] and kwargs {} returned NotImplemented
ERROR 07-15 12:50:26 [core.py:586]
ERROR 07-15 12:50:26 [core.py:586] from user code:
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 355, in forward
ERROR 07-15 12:50:26 [core.py:586] hidden_states, residual = layer(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 254, in forward
ERROR 07-15 12:50:26 [core.py:586] hidden_states = self.self_attn(
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 181, in forward
ERROR 07-15 12:50:26 [core.py:586] qkv, _ = self.qkv_proj(hidden_states)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 510, in forward
ERROR 07-15 12:50:26 [core.py:586] output_parallel = self.quant_method.apply(self, input_, bias)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/gptq_marlin.py", line 372, in apply
ERROR 07-15 12:50:26 [core.py:586] return self.kernel.apply_weights(layer, x, bias)
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py", line 122, in apply_weights
ERROR 07-15 12:50:26 [core.py:586] output = ops.machete_mm(a=x_2d,
ERROR 07-15 12:50:26 [core.py:586] File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/_custom_ops.py", line 1098, in machete_mm
ERROR 07-15 12:50:26 [core.py:586] return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
ERROR 07-15 12:50:26 [core.py:586]
ERROR 07-15 12:50:26 [core.py:586] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
ERROR 07-15 12:50:26 [core.py:586]
Process EngineCore_0:
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 590, in run_engine_core
raise e
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
engine_core = EngineCoreProc(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 404, in __init__
super().__init__(vllm_config, executor_class, log_stats,
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 82, in __init__
self._initialize_kv_caches(vllm_config)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
available_gpu_memory = self.model_executor.determine_available_memory()
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
output = self.collective_rpc("determine_available_memory")
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/utils/__init__.py", line 2736, in run_method
return func(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
self.model_runner.profile_run()
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2274, in profile_run
= self._dummy_run(self.max_num_tokens, is_profile=True)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 2057, in _dummy_run
outputs = model(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1145, in forward
hidden_states = self.language_model.model(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 239, in __call__
output = self.compiled_callable(*args, **kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: All torch_function overrides for call TorchInGraphFunctionVariable(_C.machete_mm, nonstrict_traceable=False) with args [TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), TensorVariable(), UserDefinedObjectVariable(PackedColumnParameter), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker(), LazyVariableTracker()] and kwargs {} returned NotImplemented

from user code:
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 355, in forward
hidden_states, residual = layer(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 254, in forward
hidden_states = self.self_attn(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 181, in forward
qkv, _ = self.qkv_proj(hidden_states)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 510, in forward
output_parallel = self.quant_method.apply(self, input_, bias)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/gptq_marlin.py", line 372, in apply
return self.kernel.apply_weights(layer, x, bias)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py", line 122, in apply_weights
output = ops.machete_mm(a=x_2d,
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/_custom_ops.py", line 1098, in machete_mm
return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

[rank0]:[W715 12:50:26.177353314 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 1495, in <module>
uvloop.run(run_server(args))
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/uvloop/__init__.py", line 82, in run
return loop.run_until_complete(wrapper())
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/uvloop/__init__.py", line 61, in wrapper
return await main
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 1431, in run_server
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 1451, in run_server_worker
async with build_async_engine_client(args, client_config) as engine_client:
File "/usr/lib/python3.10/contextlib.py", line 199, in __aenter__
return await anext(self.gen)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 158, in build_async_engine_client
async with build_async_engine_client_from_engine_args(
File "/usr/lib/python3.10/contextlib.py", line 199, in __aenter__
return await anext(self.gen)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 194, in build_async_engine_client_from_engine_args
async_llm = AsyncLLM.from_vllm_config(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/async_llm.py", line 162, in from_vllm_config
return cls(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/async_llm.py", line 124, in __init__
self.engine_core = EngineCoreClient.make_async_mp_client(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 96, in make_async_mp_client
return AsyncMPClient(*client_args)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 666, in __init__
super().__init__(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 403, in __init__
with launch_core_engines(vllm_config, executor_class,
File "/usr/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/utils.py", line 434, in launch_core_engines
wait_for_engine_startup(
File "/home/soren/latestvllm/venv/lib/python3.10/site-packages/vllm/v1/engine/utils.py", line 484, in wait_for_engine_startup
raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
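The Dynamo error lists ten positional arguments passed to torch.ops._C.machete_mm, and only one of them is not traced as a tensor or a lazily tracked value: a PackedColumnParameter. The small stdlib-only script below pairs each position with a parameter name to make it easier to see which input that is. The first five names come from the call shown in the traceback (a, b_q, b_type.id, out_type, b_group_scales); the last five are an assumption based on the machete_mm signature in vllm/_custom_ops.py and should be checked against the installed version.

```python
"""Pair each positional argument of the failing machete_mm call with the
variable kind Dynamo reported for it (values copied from the error above)."""

# Positional slots of torch.ops._C.machete_mm as called from
# vllm/_custom_ops.py. The first five come from the call shown in the
# traceback; the last five are ASSUMED from the vLLM 0.9.x source.
MACHETE_MM_ARGS = [
    "a",
    "b_q",
    "b_type.id",
    "out_type",
    "b_group_scales",
    "b_group_zeros",      # assumed
    "b_group_size",       # assumed
    "b_channel_scales",   # assumed
    "a_token_scales",     # assumed
    "schedule",           # assumed
]

# Dynamo's description of each positional argument, in order, from the error.
DYNAMO_KINDS = [
    "TensorVariable()",
    "TensorVariable()",
    "LazyVariableTracker()",
    "LazyVariableTracker()",
    "TensorVariable()",
    "UserDefinedObjectVariable(PackedColumnParameter)",
    "LazyVariableTracker()",
    "LazyVariableTracker()",
    "LazyVariableTracker()",
    "LazyVariableTracker()",
]


def user_defined_args(names, kinds):
    """Return (name, kind) pairs that Dynamo traced as a user-defined object
    instead of a tensor or a lazily tracked value."""
    if len(names) != len(kinds):
        raise ValueError(f"expected {len(names)} kinds, got {len(kinds)}")
    return [
        (name, kind)
        for name, kind in zip(names, kinds)
        if kind.startswith("UserDefinedObjectVariable")
    ]


if __name__ == "__main__":
    for name, kind in user_defined_args(MACHETE_MM_ARGS, DYNAMO_KINDS):
        print(f"{name}: {kind}")
```

Running it prints b_group_zeros: UserDefinedObjectVariable(PackedColumnParameter). If the assumed mapping is correct, the argument Dynamo rejects is the group zero-point input, which reaches the op as a PackedColumnParameter rather than a plain tensor. This has not been verified against the vLLM source.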

Output of vllm 0.9.1 (no bug)
INFO 07-15 12:47:01 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:47:04 [api_server.py:1287] vLLM API server version 0.9.1
INFO 07-15 12:47:05 [cli_args.py:309] non-default args: {'model': '../GPTQModel/3B', 'max_model_len': 32768, 'generation_config': 'vllm', 'limit_mm_per_prompt': {'images': 6, 'videos': 0}}
INFO 07-15 12:47:12 [config.py:823] This model supports multiple tasks: {'score', 'classify', 'embed', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 07-15 12:47:13 [gptq_marlin.py:145] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
INFO 07-15 12:47:13 [config.py:2195] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 07-15 12:47:15 [env_override.py:17] NCCL_CUMEM_ENABLE is set to 0, skipping override. This may increase memory overhead with cudagraph+allreduce: https://github.com/NVIDIA/nccl/issues/1234
INFO 07-15 12:47:17 [__init__.py:244] Automatically detected platform cuda.
INFO 07-15 12:47:20 [core.py:455] Waiting for init message from front-end.
INFO 07-15 12:47:20 [core.py:70] Initializing a V1 LLM engine (v0.9.1) with config: model='../GPTQModel/3B', speculative_config=None, tokenizer='../GPTQModel/3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=../GPTQModel/3B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
WARNING 07-15 12:47:20 [utils.py:2737] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in
INFO 07-15 12:47:26 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
WARNING 07-15 12:47:31 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 07-15 12:47:31 [gpu_model_runner.py:1595] Starting to load model ../GPTQModel/3B...
INFO 07-15 12:47:31 [gpu_model_runner.py:1600] Loading model from scratch...
WARNING 07-15 12:47:31 [vision.py:91] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
INFO 07-15 12:47:31 [gptq_marlin.py:240] Using MacheteLinearKernel for GPTQMarlinLinearMethod
INFO 07-15 12:47:32 [cuda.py:252] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00
INFO 07-15 12:47:32 [default_loader.py:272] Loading weights took 0.75 seconds
INFO 07-15 12:47:33 [gpu_model_runner.py:1624] Model loading took 3.3225 GiB and 1.250191 seconds
INFO 07-15 12:47:35 [gpu_model_runner.py:1978] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 1 image items of the maximum feature size.
INFO 07-15 12:47:46 [backends.py:462] Using cache directory: /home/soren/.cache/vllm/torch_compile_cache/03f03888ae/rank_0_0 for vLLM's torch.compile
INFO 07-15 12:47:46 [backends.py:472] Dynamo bytecode transform time: 8.82 s
INFO 07-15 12:47:49 [backends.py:161] Cache the graph of shape None for later use

INFO 07-15 12:48:18 [backends.py:173] Compiling a graph for general shape takes 31.69 s
INFO 07-15 12:48:41 [monitor.py:34] torch.compile takes 40.52 s in total
INFO 07-15 12:48:42 [gpu_worker.py:227] Available KV cache memory: 61.88 GiB
INFO 07-15 12:48:43 [kv_cache_utils.py:715] GPU KV cache size: 1,802,496 tokens
INFO 07-15 12:48:43 [kv_cache_utils.py:719] Maximum concurrency for 32,768 tokens per request: 55.01x
INFO 07-15 12:49:14 [gpu_model_runner.py:2048] Graph capturing finished in 31 secs, took 0.80 GiB
INFO 07-15 12:49:14 [core.py:171] init engine (profile, create kv cache, warmup model) took 101.19 seconds
INFO 07-15 12:49:16 [loggers.py:137] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 112656
INFO 07-15 12:49:17 [api_server.py:1349] Starting vLLM API server 0 on http://0.0.0.0:8000
INFO 07-15 12:49:17 [launcher.py:29] Available routes are:
INFO 07-15 12:49:17 [launcher.py:37] Route: /openapi.json, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /docs, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /redoc, Methods: HEAD, GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /health, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /load, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /ping, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /ping, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /tokenize, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /detokenize, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/models, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /version, Methods: GET
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/chat/completions, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/completions, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/embeddings, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /pooling, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /classify, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /score, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/score, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/audio/transcriptions, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /rerank, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v1/rerank, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /v2/rerank, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /invocations, Methods: POST
INFO 07-15 12:49:17 [launcher.py:37] Route: /metrics, Methods: GET
INFO: Started server process [67381]
INFO: Waiting for application startup.
INFO: Application startup complete.
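
The KV-cache figures in the log above can be cross-checked with simple arithmetic. The sketch below recomputes them from the logged values only; the 16-token block size is inferred from the numbers (it divides them exactly and matches vLLM's usual CUDA default) rather than read from the log, and the concurrency figure is reproduced as total KV-cache tokens divided by the 32,768-token maximum request length.

```python
# Numbers copied from the startup log above.
num_gpu_blocks = 112_656     # "num_gpu_blocks is: 112656"
kv_cache_tokens = 1_802_496  # "GPU KV cache size: 1,802,496 tokens"
max_model_len = 32_768       # "Maximum concurrency for 32,768 tokens per request"

# Tokens per KV-cache block. Inferred, not logged: the totals divide exactly.
block_size, remainder = divmod(kv_cache_tokens, num_gpu_blocks)
assert remainder == 0

# How many maximum-length requests fit in the KV cache at the same time.
max_concurrency = kv_cache_tokens / max_model_len

print(f"block_size = {block_size}")                 # block_size = 16
print(f"max concurrency = {max_concurrency:.2f}x")  # max concurrency = 55.01x
```

The 55.01x matches the "Maximum concurrency" line, so the cache sizing itself looks normal for this run.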

For convenience, I have uploaded the quantized model here:
https://huggingface.co/NM-dev/Qwen2.5-VL-3B-Instruct-GPTQModel-W4A16-G128
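
Once this checkpoint is served (for example with `vllm serve` followed by the repository ID above), the `/v1/chat/completions` route listed in the startup log accepts OpenAI-style requests with image parts. The standard-library sketch below is only an illustration; the base URL, the served model name (vLLM defaults it to the path or repo ID given at launch, which was `../GPTQModel/3B` in the log above), the image URL, and `max_tokens` are all assumptions to adapt.

```python
import json
import urllib.request

# Assumptions: the server listens on localhost:8000 and was launched with the
# uploaded repo ID, so that ID is also the served model name.
BASE_URL = "http://localhost:8000"
MODEL = "NM-dev/Qwen2.5-VL-3B-Instruct-GPTQModel-W4A16-G128"


def build_chat_request(model: str, image_url: str, question: str) -> dict:
    """Build an OpenAI-style chat request with one image part and one text part."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": question},
                ],
            }
        ],
        "max_tokens": 64,  # arbitrary illustrative limit
    }


def send(payload: dict) -> dict:
    """POST the payload to /v1/chat/completions and return the decoded JSON reply."""
    req = urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```

Calling `send(build_chat_request(MODEL, image_url, question))` performs the actual HTTP request; the reply follows the OpenAI chat-completions shape, so the answer text is in `["choices"][0]["message"]["content"]`.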

If you want to reproduce the quantized model yourself:

#!/usr/bin/env python
# coding: utf-8

import random
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.models.base import BaseGPTQModel
from transformers import AutoProcessor

SEED = 42
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
quant_path = "3B"
random.seed(SEED)

quant_config = QuantizeConfig(bits=4, group_size=128)
model = GPTQModel.load(model_path, quant_config)
processor = AutoProcessor.from_pretrained(model_path)

calibration_dataset = load_dataset(
    "HuggingFaceM4/ChartQA",
    split="train",
).select(range(256))

calib_data = []
for sample in calibration_dataset:
    try:
        messages = []
        messages.append(
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image_url": sample["image"]
                    },
                    {
                        "type": "text", "text": sample["query"],
                    },
                ],
            }
        )
        messages.append(
            {
                "role": "assistant",
                "content": sample["label"][0],
            },
        )
        text = processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        calib_data.append(inputs)
    except Exception as e:
        # Skip samples that fail preprocessing, but report the error.
        print(e)

model.prepare_dataset = lambda *args, **kwargs: BaseGPTQModel.prepare_dataset(model, *args, **kwargs)

model.quantize(calib_data, batch_size=1)
model.save(quant_path)
processor.save_pretrained(quant_path)
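
`QuantizeConfig(bits=4, group_size=128)` is what the W4A16-G128 suffix of the uploaded model refers to: linear-layer weights are stored as 4-bit integer codes with one scale per group of 128 input channels, while activations stay in 16-bit. The NumPy sketch below only illustrates that storage layout with plain symmetric round-to-nearest quantization. It is not GPTQ, which additionally uses the calibration data to decide how each weight is rounded, and the symmetric scheme with a fixed zero point of 8 is an assumption made for illustration.

```python
import numpy as np

BITS = 4
GROUP_SIZE = 128
MAXQ = 2**BITS - 1  # 15: largest unsigned 4-bit code


def quantize_groups(w: np.ndarray, group_size: int = GROUP_SIZE):
    """Symmetric per-group round-to-nearest quantization of a 2-D weight.

    `w` has shape (out_features, in_features), as in a linear layer. Each row
    is split into groups of `group_size` input channels; each group gets one
    float scale, and each weight becomes a 4-bit code in [0, 15].
    """
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    groups = w.reshape(out_features, in_features // group_size, group_size)

    # Symmetric range: the largest magnitude in each group sets the scale.
    absmax = np.abs(groups).max(axis=-1, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax)  # avoid division by zero
    scale = 2 * absmax / MAXQ
    zero = (MAXQ + 1) // 2  # 8: fixed zero point in the symmetric case

    q = np.clip(np.round(groups / scale) + zero, 0, MAXQ).astype(np.uint8)
    return q, scale.squeeze(-1), zero


def dequantize_groups(q: np.ndarray, scale: np.ndarray, zero: int) -> np.ndarray:
    """Map 4-bit codes back to floats, w_hat = scale * (q - zero), as a 2-D array."""
    w_hat = (q.astype(np.float32) - zero) * scale[..., None]
    return w_hat.reshape(q.shape[0], -1)
```

For a (4, 256) weight, `quantize_groups` returns codes of shape (4, 2, 128) and scales of shape (4, 2): two groups of 128 input channels per output row, and every reconstructed weight is within half a group scale of the original.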
