
Conversation

@NickLucche (Collaborator) commented Nov 7, 2025

PR #28012 breaks PD (prefill/decode disaggregated) deployments with TP>1; this PR reverts it. Repro:

```bash
# Spin up P (prefill) with TP=2
vllm serve Qwen/Qwen3-0.6B --port $(just port 8100) --enforce-eager --enable-log-requests --tensor-parallel-size 2 --gpu-memory-utilization 0.4 --trust-remote-code --max-model-len 32768 --block-size 128 --data_parallel_size 1 --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
# Spin up toy_proxy_server in the background...

# Send a test request
curl -X POST http://localhost:$(just port 8192)/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "{{MODEL}}",
    "prompt": "Can you complete this latin sentence: Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
    "max_tokens": 150,
    "temperature": 0.2
  }'
```

Observe the request getting stuck:

```
(APIServer pid=2822821) INFO 11-07 11:08:44 [launcher.py:46] Route: /invocations, Methods: POST
(APIServer pid=2822821) INFO 11-07 11:08:44 [launcher.py:46] Route: /metrics, Methods: GET
(APIServer pid=2822821) INFO:     Started server process [2822821]
(APIServer pid=2822821) INFO:     Waiting for application startup.
(APIServer pid=2822821) INFO:     Application startup complete.
(APIServer pid=2822821) INFO 11-07 11:10:17 [logger.py:47] Received request cmpl-700daa72-8cc7-4a65-9bff-cb83c40b16a7-0: params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.2, top_p=0.95, top_k=20, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=1, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args={'kv_transfer_params': {'do_remote_decode': True, 'do_remote_prefill': False, 'remote_engine_id': None, 'remote_block_ids': None, 'remote_host': None, 'remote_port': None}}), lora_request: None.
(APIServer pid=2822821) INFO 11-07 11:10:17 [async_llm.py:343] Added request cmpl-700daa72-8cc7-4a65-9bff-cb83c40b16a7-0.
```

@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request reverts a previous performance-related change (#28012) that was causing issues in deployments with tensor parallelism greater than one. The revert re-introduces a dedicated I/O thread in the multiprocessor executor to handle worker responses, removing the custom FutureWrapper and its associated logic. The changes touch several components, including executors, KV connector utilities, and tests, to align with the restored asynchronous execution model. My review identified a critical issue with an incorrect type hint that could lead to runtime errors, and a high-severity thread-safety concern in the asynchronous aggregation logic that, while not causing a bug with the current configuration, is fragile and should be addressed to prevent future issues.
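
For context, the pattern being restored can be pictured as follows: the blocking "collect one response per worker" step runs on a single dedicated I/O thread, and the scheduler gets back a plain `concurrent.futures.Future` instead of a custom `FutureWrapper`. The snippet below is an illustrative toy sketch only; names such as `TinyExecutor`, `DummyWorker`, and `collect_worker_outputs` are hypothetical, not vLLM's actual classes.

```python
from concurrent.futures import ThreadPoolExecutor
import time


class DummyWorker:
    """Hypothetical stand-in for a worker process handle (not vLLM code)."""

    def get_response(self):
        time.sleep(0.01)  # pretend to wait on the worker's response queue
        return "output"


def collect_worker_outputs(workers):
    # Blocking step: read one response per worker, in order.
    return [w.get_response() for w in workers]


class TinyExecutor:
    """Toy executor illustrating the restored pattern."""

    def __init__(self, workers):
        self.workers = workers
        # A single dedicated I/O thread; responses are collected serially on it.
        self.io_thread_pool = ThreadPoolExecutor(max_workers=1)

    def execute_model(self, non_block: bool = False):
        if not non_block:
            return collect_worker_outputs(self.workers)
        # Non-blocking path: hand the scheduler a plain concurrent.futures.Future
        # instead of a custom FutureWrapper.
        return self.io_thread_pool.submit(collect_worker_outputs, self.workers)


if __name__ == "__main__":
    executor = TinyExecutor([DummyWorker(), DummyWorker()])
    future = executor.execute_model(non_block=True)
    print(future.result())  # ['output', 'output']
```

Because the pool has exactly one worker thread, completion callbacks and aggregation run serially on it, which is the single-thread assumption the aggregation code discussed below relies on.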

```diff
 @torch.inference_mode()
 def sample_tokens(
-    self, grammar_output: "GrammarOutput | None"
+    self, grammar_output: "GrammarOutput"
```

critical

The type hint for grammar_output has been changed to GrammarOutput, but callers of this method (e.g., in vllm/v1/executor/abstract.py) can pass None. This creates a discrepancy between the type hint and the actual usage, and can lead to a runtime AttributeError if None is passed and its attributes are accessed downstream. The type hint should be reverted to GrammarOutput | None to accurately reflect that None is a valid value.

Suggested change:

```diff
-    self, grammar_output: "GrammarOutput"
+    self, grammar_output: "GrammarOutput | None"
```
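
As a toy illustration of the concern (hypothetical stand-ins, not the real vLLM types): narrowing the annotation does not stop existing callers from passing `None`; it only hides the case from type checkers until an attribute access fails at runtime.

```python
from dataclasses import dataclass, field


@dataclass
class GrammarOutput:  # toy stand-in, not vLLM's GrammarOutput
    token_bitmask: list[int] = field(default_factory=list)


def sample_tokens(grammar_output: "GrammarOutput") -> int:
    # With the narrowed hint, nothing here guards against None.
    return len(grammar_output.token_bitmask)


def sample_tokens_optional(grammar_output: "GrammarOutput | None") -> int:
    # The wider hint makes the None branch explicit for callers and checkers.
    if grammar_output is None:
        return 0
    return len(grammar_output.token_bitmask)


print(sample_tokens_optional(None))  # 0
# print(sample_tokens(None))         # AttributeError: 'NoneType' object has no attribute 'token_bitmask'
```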

Comment on lines +231 to +255:

```python
        outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
        remaining = len(output_futures)

        def make_callback(idx):
            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                except Exception as e:
                    result_future.set_exception(e)

                # this check assumes io_thread_pool uses a single thread
                nonlocal remaining
                remaining -= 1
                if not remaining:
                    result_future.set_result(self.aggregate(outputs, output_rank))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))
```

high

The current implementation of async_aggregate is not thread-safe. The remaining counter is accessed and modified without a lock. While the comment on line 246 correctly points out the assumption of a single-threaded I/O pool, this design is fragile. If the ThreadPoolExecutor in MultiprocExecutor is ever configured with more than one worker, this will introduce a race condition, leading to incorrect behavior. To make this implementation robust and thread-safe, a lock should be used to protect the shared remaining counter.

```python
        from threading import Lock
        outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
        remaining = len(output_futures)
        lock = Lock()

        def make_callback(idx):
            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                except Exception as e:
                    result_future.set_exception(e)

                with lock:
                    # This check is now thread-safe.
                    nonlocal remaining
                    remaining -= 1
                    if not remaining:
                        if not result_future.done():
                            result_future.set_result(self.aggregate(outputs, output_rank))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))
```


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +144 to +156:

```diff
     def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
         super().__init__()
-        self.ref_or_refs = ref_or_refs
+        self.refs = refs
         self.aggregator = aggregator

     def result(self, timeout=None):
         if timeout is not None:
             raise NotImplementedError("timeout is not supported")

-        outputs = ray.get(self.ref_or_refs, timeout=timeout)
         if self.aggregator is None:
-            return outputs
+            return self.refs[0].get()

+        outputs = [ref.get() for ref in self.refs]
```


P0: Ray futures retrieved via nonexistent `ObjectRef.get`

The new FutureWrapper.result() calls self.refs[0].get() and [ref.get() for ref in self.refs]. Ray object references don’t expose a .get() method; they must be resolved with ray.get(ref) or ray.get(refs). This means any non-blocking path that returns a FutureWrapper will immediately raise AttributeError when the scheduler awaits the result, breaking Ray execution entirely.
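
For reference, a minimal sketch of the distinction (assuming a local Ray runtime; this is not the vLLM code path): object refs are resolved through the module-level `ray.get`, which accepts either a single `ObjectRef` or a list of them.

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def square(i: int) -> int:
    return i * i


refs = [square.remote(i) for i in range(4)]  # list of ObjectRef

print(ray.get(refs[0]))  # 0            -- resolve a single ref
print(ray.get(refs))     # [0, 1, 4, 9] -- resolve a list of refs

# ObjectRef has no .get() method; the line below would raise AttributeError.
# refs[0].get()

ray.shutdown()
```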


Comment on lines 435 to +445:

```diff
             # When PP is used, we return a FutureWrapper immediately so that
             # the scheduler can yield to the next batch.
-            return FutureWrapper(refs[0])
+            return FutureWrapper(refs)

         # Get output from all workers when connector is present
         assert self.kv_output_aggregator is not None
         if not non_block:
             # Block and get results from all workers
-            return self.kv_output_aggregator.aggregate(ray.get(refs))
+            outputs = [ref.get() for ref in refs]
+            return self.kv_output_aggregator.aggregate(outputs)
```


P0: Blocking Ray sampling uses `.get()` on `ObjectRef`

RayDistributedExecutor.sample_tokens now resolves Ray outputs with refs[0].get() and [ref.get() for ref in refs]. ObjectRef does not provide a get() API, so any synchronous sampling (with or without a KV connector) will fail with AttributeError before a response is returned. Use ray.get(refs) to retrieve the values.


@NickLucche added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) on Nov 7, 2025
@DarkLight1337 enabled auto-merge (squash) on November 7, 2025, 14:36
@DarkLight1337 merged commit 68a72a5 into vllm-project:main on Nov 7, 2025
54 checks passed
@NickLucche deleted the fix-nixl-thread branch on November 7, 2025, 15:08

@njhill (Member) commented Nov 7, 2025

Here is the change again with a fix: #28319

Thanks @NickLucche and @DarkLight1337

ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels: `kv-connector`, `ready` (ONLY add when PR is ready to merge/full CI is needed), `v1`
