Skip to content
38 changes: 22 additions & 16 deletions vllm/v1/structured_output/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput

if TYPE_CHECKING:
Expand Down Expand Up @@ -106,28 +107,33 @@ def apply_grammar_bitmask(
# since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0]

index_tensor = None
if not skip_out_indices:
# xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)
if not logits.is_cpu:
index_tensor = None
if not skip_out_indices:
# xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a
# non_blocking manner and there should be no cpu sync within xgrammar.
pin_memory = is_pin_memory_available()
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)

xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
return

# CPU case, use list for indices.
indices = None if skip_out_indices else out_indices
# Handle dtype conversion for CPU (older xgrammar CPU kernels require float32)
# See: https://github.com/vllm-project/vllm/issues/31901
if logits.device.type == "cpu" and logits.dtype != torch.float32:
if logits.dtype != torch.float32:
# Convert to float32, apply bitmask, then convert back
logits_float32 = logits.to(torch.float32)
xgr.apply_token_bitmask_inplace(
logits_float32, grammar_bitmask, indices=index_tensor
)
logits_fp32 = logits.to(torch.float32)
xgr.apply_token_bitmask_inplace(logits_fp32, grammar_bitmask, indices=indices)
# Copy the modified values back to the original tensor
logits.copy_(logits_float32.to(logits.dtype))
logits.copy_(logits_fp32.to(logits.dtype))
else:
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=indices)


class OutlinesVocabulary:
Expand Down
Loading