[Bugfix] Fix structured output crash on CPU due to pin_memory=True#37706

Merged
njhill merged 8 commits into vllm-project:main from wjhrdy:fix/cpu-structured-output-pin-memory
Mar 24, 2026

Conversation

@wjhrdy
Contributor

@wjhrdy wjhrdy commented Mar 20, 2026

Essential Checks

  • PR title follows the pattern [Tag] Short description
  • I have searched for related issues and checked existing PRs
  • I have run linting/formatting locally

Purpose

Fix RuntimeError: pin_memory=True requires a CUDA or other accelerator backend crash when using structured output (guided decoding) on CPU-only deployments.

Fixes #37705

Problem

apply_grammar_bitmask() in vllm/v1/structured_output/utils.py crashes on CPU when handling mixed batches (concurrent structured + non-structured requests):

  1. pin_memory=True is hardcoded: torch.tensor(out_indices, ..., pin_memory=True) requires CUDA and fails on CPU-only systems.
  2. The xgrammar CPU kernel expects Sequence[int], not torch.Tensor: apply_token_bitmask_inplace_cpu() only accepts a Python list for the indices argument.

Note: the existing CPU float32 workaround (added in #31901) was never reachable because the pin_memory=True crash occurs first.
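The first failure can be reproduced in isolation. Below is a minimal standalone sketch (the helper name make_index_tensor is illustrative, not the actual vLLM code) contrasting the crashing construction with the safe one:

```python
import torch

def make_index_tensor(out_indices, pin_memory):
    # Mirrors the torch.tensor() call in apply_grammar_bitmask();
    # the function name here is illustrative only.
    return torch.tensor(
        out_indices, dtype=torch.int32, device="cpu", pin_memory=pin_memory
    )

try:
    make_index_tensor([0, 2, 5], pin_memory=True)
    print("pin_memory=True succeeded (an accelerator backend is available)")
except RuntimeError as err:
    # On CPU-only builds this raises:
    # "pin_memory=True requires a CUDA or other accelerator backend"
    print(f"reproduced: {err}")

# The unpinned construction works on any build.
print(make_index_tensor([0, 2, 5], pin_memory=False).tolist())
```

Whether the first branch succeeds or raises depends on the build; the unpinned path works either way.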

Fix

On CPU, pass out_indices as a plain Python list directly instead of converting to a pinned tensor. The GPU path with pinned memory is preserved.

Test Plan

Tested by starting vLLM on CPU with ibm-granite/granite-3.2-2b-instruct, then sending concurrent plain + structured output (response_format: json_schema) requests. Without the fix, both requests return 500 and the EngineCore dies. With the fix, both succeed and the server stays healthy.

import concurrent.futures
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
MODEL = "ibm-granite/granite-3.2-2b-instruct"

def plain_request():
    return client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": "Tell me a story"}],
        max_tokens=200,
    )

def structured_request():
    return client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        max_tokens=50,
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "resp", "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {"capital": {"type": "string"}},
                    "required": ["capital"],
                    "additionalProperties": False,
                },
            },
        },
    )

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    f1 = executor.submit(plain_request)
    f2 = executor.submit(structured_request)
    print(f1.result())
    print(f2.result())

On CPU-only deployments, `apply_grammar_bitmask()` crashes with
`RuntimeError: pin_memory=True requires a CUDA or other accelerator
backend` when handling mixed batches of structured and non-structured
requests.

Two issues:
1. `pin_memory=True` is hardcoded in the `torch.tensor()` call for
   `out_indices` — this requires CUDA and fails on CPU.
2. The xgrammar CPU kernel (`apply_token_bitmask_inplace_cpu`)
   expects `Sequence[int]` for the `indices` argument, not a tensor.

Note: the existing CPU float32 workaround added in vllm-project#31901 was never
reachable because the `pin_memory=True` crash occurs first.

Fix: on CPU, pass `out_indices` as a plain Python list. The GPU path
with pinned memory is preserved.

Fixes vllm-project#37705

Signed-off-by: Willy Hardy <whardy@redhat.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively resolves a critical RuntimeError that occurred on CPU-only deployments due to pin_memory=True being hardcoded for torch.tensor creation. The changes correctly introduce conditional logic to handle CPU and GPU devices separately, ensuring that pin_memory=True is only applied when a CUDA device is available. Furthermore, it addresses the xgrammar CPU kernel's expectation of a Python list for indices by passing out_indices directly on CPU, which is a significant improvement for correctness and stability in mixed-batch scenarios. The updated type hint for indices also enhances code clarity.

Comment on lines +111 to +119
if logits.device.type == "cpu":
    # On CPU, pass indices as a plain list — pin_memory requires CUDA,
    # and the xgrammar CPU kernel expects Sequence[int], not a tensor.
    indices = out_indices
else:
    indices = torch.tensor(
        out_indices, dtype=torch.int32, device="cpu", pin_memory=True,
    )
    indices = indices.to(logits.device, non_blocking=True)
Contributor


critical

This conditional logic is a critical fix. By checking logits.device.type, the code now correctly avoids setting pin_memory=True on CPU, which was causing a RuntimeError. Additionally, passing out_indices as a plain Python list for CPU devices directly addresses the xgrammar CPU kernel's expectation for a Sequence[int], preventing potential issues with type mismatches.

Contributor

@dougbtv dougbtv left a comment


Looks excellent -- do we need any validation on the testing side?

Member

@mgoin mgoin left a comment


Seems reasonable to me, thanks!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 20, 2026
@mergify

mergify bot commented Mar 20, 2026

Hi @wjhrdy, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Willy Hardy <whardy@redhat.com>
@wjhrdy wjhrdy force-pushed the fix/cpu-structured-output-pin-memory branch from 030f141 to e97ab92 Compare March 20, 2026 20:54
Collaborator

@andy-neuma andy-neuma left a comment


thanks

wjhrdy and others added 2 commits March 23, 2026 09:35
- Use logits.is_cpu instead of logits.device.type == "cpu"
- Restore original comment explaining non_blocking tensor copy in else branch
- Consolidate tensor creation formatting

Signed-off-by: Will Hardy <whardy@redhat.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
if logits.is_cpu:
    # On CPU, pass indices as a plain list — pin_memory requires CUDA,
    # and the xgrammar CPU kernel expects Sequence[int], not a tensor.
    indices = out_indices
Collaborator


why rename this variable?

Address review feedback from benchislett: use is_pin_memory_available()
for pin_memory instead of branching on device type. This eliminates
the CPU-specific code path entirely. Also reverts variable name back
to index_tensor (original name).

Signed-off-by: Will Hardy <whardy@redhat.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@wjhrdy
Contributor Author

wjhrdy commented Mar 23, 2026

Addressed all review feedback:

  • @njhill: Used logits.is_cpu, restored the original comment in the else branch, cleaned up formatting
  • @benchislett: Simplified to use is_pin_memory_available() instead of device-type branching — no more CPU-specific code path needed. Reverted variable name back to index_tensor (original name).

Note: the CPU machine I normally test these changes on is currently down, so this latest update is untested. Will validate once the machine is back up.
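The simplified shape described above can be sketched as follows. This is a standalone illustration, not the merged code: vLLM's is_pin_memory_available() helper is stubbed here with a plain CUDA check, and build_index_tensor is a hypothetical name.

```python
import torch

def is_pin_memory_available() -> bool:
    # Stub standing in for vLLM's helper of the same name,
    # for illustration only.
    return torch.cuda.is_available()

def build_index_tensor(out_indices: list, device: torch.device) -> torch.Tensor:
    # Pin host memory only when an accelerator backend can use it,
    # so the same code path works on CPU-only builds.
    index_tensor = torch.tensor(
        out_indices,
        dtype=torch.int32,
        device="cpu",
        pin_memory=is_pin_memory_available(),
    )
    # non_blocking=True is a no-op when the target device is already CPU.
    return index_tensor.to(device, non_blocking=True)

print(build_index_tensor([0, 2], torch.device("cpu")).tolist())
```

With this shape there is a single code path; the only device-dependent decision is the pin_memory flag.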

njhill added 2 commits March 24, 2026 09:13
Signed-off-by: Nick Hill <nickhill123@gmail.com>
@njhill
Member

njhill commented Mar 24, 2026

Thanks @wjhrdy. I reworked it a bit to separate the CPU and non-CPU cases after all since I noticed that for CPU, xgrammar just converts the tensor back to a list, and there was already some cpu-specific logic.

@njhill njhill enabled auto-merge (squash) March 24, 2026 16:22
@njhill njhill merged commit 057fc94 into vllm-project:main Mar 24, 2026
49 checks passed
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Mar 27, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
malaiwah pushed a commit to malaiwah/vllm that referenced this pull request Mar 27, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…llm-project#37706)

Signed-off-by: Willy Hardy <whardy@redhat.com>
Signed-off-by: Will Hardy <whardy@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>

Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed structured-output v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: Structured output crashes on CPU with pin_memory=True in apply_grammar_bitmask()

6 participants