[Bugfix] Fix ROCm crash in qwen3_next multi-stream events (#36795) #37427

Merged
yewentao256 merged 1 commit into vllm-project:main from JartX:fix/rocm-qwen3-next-multi-stream-events
Mar 18, 2026

Conversation

@JartX (Contributor) commented Mar 18, 2026

PR #36795 introduced maybe_execute_in_parallel for qwen3_next, but it gated CUDA event creation on is_cuda() while the aux stream check uses is_cuda_alike(). On ROCm the aux stream is therefore created while the events stay None, which causes AttributeError: 'NoneType' object has no attribute 'record'.

Fix: change the event guard from is_cuda() to is_cuda_alike() so that it matches the stream check.
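
The guard mismatch can be reproduced without a GPU. The following is a minimal sketch, not vLLM's actual code: `FakePlatform`, `FakeEvent`, `make_resources`, and `run_parallel` are hypothetical stand-ins, and it assumes only what the description states, namely that the stream is created under `is_cuda_alike()` while the event was created under `is_cuda()`.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass
class FakePlatform:
    """Hypothetical stand-in for a platform object (not vLLM's class)."""
    name: str  # "cuda", "rocm", or "cpu"

    def is_cuda(self) -> bool:
        return self.name == "cuda"

    def is_cuda_alike(self) -> bool:
        # Assumption for this sketch: ROCm counts as CUDA-alike.
        return self.name in ("cuda", "rocm")


class FakeEvent:
    """Hypothetical stand-in for a device event with a record() method."""

    def __init__(self) -> None:
        self.recorded = False

    def record(self) -> None:
        self.recorded = True


def make_resources(
    platform: FakePlatform,
    event_guard: Callable[[FakePlatform], bool],
) -> Tuple[Optional[object], Optional[FakeEvent]]:
    """Create an aux stream and an event under possibly different guards."""
    aux_stream = object() if platform.is_cuda_alike() else None
    event = FakeEvent() if event_guard(platform) else None
    return aux_stream, event


def run_parallel(aux_stream: Optional[object], event: Optional[FakeEvent]) -> str:
    """Take the parallel path whenever an aux stream exists."""
    if aux_stream is None:
        return "serial"
    # If the event was guarded more strictly than the stream, event is None
    # here and this line raises AttributeError.
    event.record()
    return "parallel"


if __name__ == "__main__":
    rocm = FakePlatform("rocm")
    # Mismatched guards: the stream exists but the event does not.
    try:
        run_parallel(*make_resources(rocm, FakePlatform.is_cuda))
    except AttributeError as exc:
        print("before fix:", exc)
    # Matching guards: both exist, so the parallel path works.
    print("after fix:", run_parallel(*make_resources(rocm, FakePlatform.is_cuda_alike)))
```

Using the same predicate for both resources keeps them either both present or both absent, which is the invariant the one-line fix restores.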

Error:


(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936] Traceback (most recent call last):
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 931, in worker_busy_loop
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     output = func(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return func(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 388, in determine_available_memory
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     self.model_runner.profile_run()
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 5541, in profile_run
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     hidden_states, last_hidden_states = self._dummy_run(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]                                         ^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return func(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 5235, in _dummy_run
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     outputs = self.model(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]               ^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/cuda_graph.py", line 251, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self.runnable(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._call_impl(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return forward_call(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_5.py", line 769, in forward
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     hidden_states = self.language_model.model(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 597, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     output = TorchCompileWithNoGuardsWrapper.__call__(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/wrapper.py", line 182, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._call_with_optional_nvtx_range(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/wrapper.py", line 76, in _call_with_optional_nvtx_range
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return callable_fn(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 832, in compile_wrapper
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return fn(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_next.py", line 1384, in forward
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     def forward(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return fn(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/caching.py", line 206, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self.optimized_call(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._wrapped_call(self, *args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     raise e
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._call_impl(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return forward_call(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "<eval_with_key>.82", line 249, in forward
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     submod_0 = self.submod_0(l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_inputs_embeds_, s59, s18);  l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = None
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/cuda_graph.py", line 251, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self.runnable(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/piecewise_backend.py", line 367, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return range_entry.runnable(*args)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/standalone_compile.py", line 63, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._compiled_fn(*args)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return fn(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1130, in forward
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return compiled_fn(full_args)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 353, in runtime_wrapper
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     all_outs = call_func_at_runtime_with_args(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 129, in call_func_at_runtime_with_args
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     out = normalize_as_list(f(args))
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]                             ^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 526, in wrapper
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return compiled_fn(runtime_args)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 613, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self.current_callable(inputs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 3017, in run
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     out = model(new_inputs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]           ^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/tmp/torchinductor_root/un/cunlm4f5zs2ziwfx2w2p4ygiqtcynp5rms26a6vfiztpeyw3wpyo.py", line 625, in call
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     buf2 = torch.ops.vllm.gdn_in_proj.default(buf1, 6144, 32, 'language_model.model.layers.0.linear_attn')
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 841, in __call__
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._op(*args, **kwargs)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_next.py", line 1710, in gdn_in_proj
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     return self._forward_in_proj(hidden_states)
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_next.py", line 803, in _forward_in_proj
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils/multi_stream_utils.py", line 38, in maybe_execute_in_parallel
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     event0.record()
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936]     ^^^^^^^^^^^^^
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936] AttributeError: 'NoneType' object has no attribute 'record'
(Worker_TP0 pid=50999) ERROR 03-18 07:36:48 [multiproc_executor.py:936] 
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099] EngineCore failed to start.
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099] Traceback (most recent call last):
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 1073, in run_engine_core
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 839, in __init__
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     super().__init__(
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 122, in __init__
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     kv_cache_config = self._initialize_kv_caches(vllm_config)
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 245, in _initialize_kv_caches
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 136, in determine_available_memory
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     return self.collective_rpc("determine_available_memory")
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 397, in collective_rpc
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     return aggregate(get_response())
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]                      ^^^^^^^^^^^^^^
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 380, in get_response
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099]     raise RuntimeError(
(EngineCore pid=50897) ERROR 03-18 07:36:48 [core.py:1099] RuntimeError: Worker failed with error ''NoneType' object has no attribute 'record'', please check the stack trace above for the root cause
(Worker_TP0 pid=50999) WARNING 03-18 07:36:48 [multiproc_executor.py:870] WorkerProc was terminated
(Worker_TP1 pid=51000) WARNING 03-18 07:36:48 [multiproc_executor.py:870] WorkerProc was terminated
(EngineCore pid=50897) ERROR 03-18 07:36:50 [multiproc_executor.py:273] Worker proc VllmWorker-0 died unexpectedly, shutting down executor.
(EngineCore pid=50897) Process EngineCore:
(EngineCore pid=50897) Traceback (most recent call last):
(EngineCore pid=50897)   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore pid=50897)     self.run()
(EngineCore pid=50897)   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore pid=50897)     self._target(*self._args, **self._kwargs)
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 1103, in run_engine_core
(EngineCore pid=50897)     raise e
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 1073, in run_engine_core
(EngineCore pid=50897)     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore pid=50897)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=50897)     return func(*args, **kwargs)
(EngineCore pid=50897)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 839, in __init__
(EngineCore pid=50897)     super().__init__(
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 122, in __init__
(EngineCore pid=50897)     kv_cache_config = self._initialize_kv_caches(vllm_config)
(EngineCore pid=50897)                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=50897)     return func(*args, **kwargs)
(EngineCore pid=50897)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 245, in _initialize_kv_caches
(EngineCore pid=50897)     available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore pid=50897)                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 136, in determine_available_memory
(EngineCore pid=50897)     return self.collective_rpc("determine_available_memory")
(EngineCore pid=50897)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 397, in collective_rpc
(EngineCore pid=50897)     return aggregate(get_response())
(EngineCore pid=50897)                      ^^^^^^^^^^^^^^
(EngineCore pid=50897)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 380, in get_response
(EngineCore pid=50897)     raise RuntimeError(
(EngineCore pid=50897) RuntimeError: Worker failed with error ''NoneType' object has no attribute 'record'', please check the stack trace above for the root cause
(APIServer pid=50780) Traceback (most recent call last):
(APIServer pid=50780)   File "/usr/local/bin/vllm", line 10, in <module>
(APIServer pid=50780)     sys.exit(main())
(APIServer pid=50780)              ^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/main.py", line 75, in main
(APIServer pid=50780)     args.dispatch_function(args)
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/serve.py", line 118, in cmd
(APIServer pid=50780)     uvloop.run(run_server(args))
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 96, in run
(APIServer pid=50780)     return __asyncio.run(
(APIServer pid=50780)            ^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
(APIServer pid=50780)     return runner.run(main)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
(APIServer pid=50780)     return self._loop.run_until_complete(task)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 48, in wrapper
(APIServer pid=50780)     return await main
(APIServer pid=50780)            ^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 675, in run_server
(APIServer pid=50780)     await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 689, in run_server_worker
(APIServer pid=50780)     async with build_async_engine_client(
(APIServer pid=50780)                ^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
(APIServer pid=50780)     return await anext(self.gen)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 104, in build_async_engine_client
(APIServer pid=50780)     async with build_async_engine_client_from_engine_args(
(APIServer pid=50780)                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
(APIServer pid=50780)     return await anext(self.gen)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 145, in build_async_engine_client_from_engine_args
(APIServer pid=50780)     async_llm = AsyncLLM.from_vllm_config(
(APIServer pid=50780)                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 225, in from_vllm_config
(APIServer pid=50780)     return cls(
(APIServer pid=50780)            ^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 154, in __init__
(APIServer pid=50780)     self.engine_core = EngineCoreClient.make_async_mp_client(
(APIServer pid=50780)                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(APIServer pid=50780)     return func(*args, **kwargs)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 128, in make_async_mp_client
(APIServer pid=50780)     return AsyncMPClient(*client_args)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(APIServer pid=50780)     return func(*args, **kwargs)
(APIServer pid=50780)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 924, in __init__
(APIServer pid=50780)     super().__init__(
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 583, in __init__
(APIServer pid=50780)     with launch_core_engines(
(APIServer pid=50780)          ^^^^^^^^^^^^^^^^^^^^
(APIServer pid=50780)   File "/usr/lib/python3.12/contextlib.py", line 144, in __exit__
(APIServer pid=50780)     next(self.gen)
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/utils.py", line 972, in launch_core_engines
(APIServer pid=50780)     wait_for_engine_startup(
(APIServer pid=50780)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/utils.py", line 1031, in wait_for_engine_startup
(APIServer pid=50780)     raise RuntimeError(
(APIServer pid=50780) RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 2 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

@JartX JartX requested a review from sighingnow as a code owner March 18, 2026 12:32
@mergify mergify bot added qwen Related to Qwen models rocm Related to AMD ROCm bug Something isn't working labels Mar 18, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 18, 2026
@JartX
Contributor Author

JartX commented Mar 18, 2026

@xyang16 @DarkLight1337 @tjtanaa please take a look :)

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a crash on ROCm devices in the qwen3_next model. The issue stems from CUDA events being initialized only on CUDA platforms (using is_cuda()), while the auxiliary stream used for parallel execution is created on all CUDA-like platforms, including ROCm. This discrepancy leads to None events being passed to operations that expect valid event objects on ROCm, causing an AttributeError. The fix correctly aligns the event creation logic with the stream creation logic by using is_cuda_alike(), ensuring that events are properly initialized on ROCm.
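The guard mismatch described in this review can be reproduced without a GPU. The following is a minimal sketch, not the vLLM code: `FakePlatform`, `FakeEvent`, `make_events`, `has_aux_stream`, `run_parallel`, and `try_run` are all hypothetical names invented for illustration. Only the method names `is_cuda()` and `is_cuda_alike()` and the failing `record()` call are taken from the PR description and traceback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakePlatform:
    """Hypothetical stand-in for a platform object (not vLLM's real class)."""
    name: str  # "cuda", "rocm", or "cpu"

    def is_cuda(self) -> bool:
        # True on NVIDIA CUDA only.
        return self.name == "cuda"

    def is_cuda_alike(self) -> bool:
        # True on CUDA and on ROCm.
        return self.name in ("cuda", "rocm")


class FakeEvent:
    """Hypothetical stand-in for a device event."""

    def __init__(self) -> None:
        self.recorded = False

    def record(self) -> None:
        self.recorded = True


def make_events(platform, guard):
    """Create two events only if guard(platform) is true, else placeholders."""
    if guard(platform):
        return [FakeEvent(), FakeEvent()]
    return [None, None]


def has_aux_stream(platform) -> bool:
    """Assumed stream check: an aux stream exists on every CUDA-like platform."""
    return platform.is_cuda_alike()


def run_parallel(platform, events) -> str:
    """Take the parallel path whenever an aux stream exists."""
    if not has_aux_stream(platform):
        return "sequential"
    events[0].record()  # fails if the event guard was narrower than the stream check
    return "parallel"


def try_run(platform, guard) -> str:
    """Run once and report either the path taken or the error raised."""
    events = make_events(platform, guard)
    try:
        return run_parallel(platform, events)
    except AttributeError as exc:
        return f"AttributeError: {exc}"


if __name__ == "__main__":
    rocm = FakePlatform("rocm")
    print(try_run(rocm, FakePlatform.is_cuda))        # guard before the fix
    print(try_run(rocm, FakePlatform.is_cuda_alike))  # guard after the fix
```

In this sketch, the narrower `is_cuda` guard leaves the events as `None` on the simulated ROCm platform while the stream check still selects the parallel path, which is the same shape as the failure in the traceback. Using one predicate for both checks removes the inconsistency.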

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
@xyang16
Contributor

xyang16 commented Mar 18, 2026

Thanks for the fix! Sorry for breaking ROCm.

@JartX
Contributor Author

JartX commented Mar 18, 2026

@xyang16 No problem, that's what we're here for :)
Thanks for your contribution!

Member

@yewentao256 yewentao256 left a comment

LGTM, thanks for the work!

@yewentao256 yewentao256 merged commit a913b61 into vllm-project:main Mar 18, 2026
58 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 18, 2026
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request Mar 19, 2026
…ct#36795) (vllm-project#37427)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…ct#36795) (vllm-project#37427)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026

Labels

bug Something isn't working qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

4 participants