diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 0d31363cb5b4..9fc1accb2e2c 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.import_utils import LazyLoader +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput if TYPE_CHECKING: @@ -106,28 +107,33 @@ def apply_grammar_bitmask( # since the bitmask is already aligned with the logits. skip_out_indices = len(out_indices) == logits.shape[0] - index_tensor = None - if not skip_out_indices: - # xgrammar expects a python list of indices but it will actually work with - # a tensor. If we copy the tensor ourselves here we can do it in a non_blocking - # manner and there should be no cpu sync within xgrammar. - index_tensor = torch.tensor( - out_indices, dtype=torch.int32, device="cpu", pin_memory=True - ) - index_tensor = index_tensor.to(logits.device, non_blocking=True) + if not logits.is_cpu: + index_tensor = None + if not skip_out_indices: + # xgrammar expects a python list of indices but it will actually work with + # a tensor. If we copy the tensor ourselves here we can do it in a + # non_blocking manner and there should be no cpu sync within xgrammar. + pin_memory = is_pin_memory_available() + index_tensor = torch.tensor( + out_indices, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + index_tensor = index_tensor.to(logits.device, non_blocking=True) + + xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor) + return + # CPU case, use list for indices. + indices = None if skip_out_indices else out_indices # Handle dtype conversion for CPU (older xgrammar CPU kernels require float32) # See: https://github.com/vllm-project/vllm/issues/31901 - if logits.device.type == "cpu" and logits.dtype != torch.float32: + if logits.dtype != torch.float32: # Convert to float32, apply bitmask, then convert back - logits_float32 = logits.to(torch.float32) - xgr.apply_token_bitmask_inplace( - logits_float32, grammar_bitmask, indices=index_tensor - ) + logits_fp32 = logits.to(torch.float32) + xgr.apply_token_bitmask_inplace(logits_fp32, grammar_bitmask, indices=indices) # Copy the modified values back to the original tensor - logits.copy_(logits_float32.to(logits.dtype)) + logits.copy_(logits_fp32.to(logits.dtype)) else: - xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor) + xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=indices) class OutlinesVocabulary: