5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -143,6 +143,7 @@
VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_COMPUTE_PADDED_LOGITS_INDICES: bool = False


def get_default_cache_root():
@@ -991,6 +992,10 @@ def get_vllm_port() -> Optional[int]:
# The default value is "VLLM".
"VLLM_PROCESS_NAME_PREFIX":
lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"),

# Enable computing CUDA-graph-padded logits indices and propagating
# them through the forward context
"VLLM_COMPUTE_PADDED_LOGITS_INDICES":
lambda: bool(int(os.getenv("VLLM_COMPUTE_PADDED_LOGITS_INDICES", "0"))),
}

# --8<-- [end:env-vars-definition]
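
For context, the new flag follows the existing lazy-lookup pattern in envs.py, so downstream code reads it as envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES. A minimal usage sketch (illustrative only, not part of the diff; assumes vLLM is installed):

# Typically enabled via the shell before launching vLLM:
#   export VLLM_COMPUTE_PADDED_LOGITS_INDICES=1
import vllm.envs as envs

if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES:
    # gpu_model_runner.py will compute logits_indices_padded and attach it
    # to the per-pass ForwardContext.
    print("padded logits indices enabled")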
3 changes: 3 additions & 0 deletions vllm/forward_context.py
@@ -95,6 +95,7 @@ class ForwardContext:
# set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False
logits_indices_padded: Optional[torch.Tensor] = None


_forward_context: Optional[ForwardContext] = None
@@ -116,6 +117,7 @@ def set_forward_context(
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False,
logits_indices_padded: Optional[torch.Tensor] = None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
@@ -141,6 +143,7 @@
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs,
logits_indices_padded=logits_indices_padded,
)

try:
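
The new field rides along with the rest of the per-pass state, so anything running inside the set_forward_context block can pick it up. A hedged sketch of a consumer; get_forward_context lives in this module, while the helper itself is hypothetical:

from typing import Optional

import torch
from vllm.forward_context import get_forward_context

def current_padded_logits_indices() -> Optional[torch.Tensor]:
    # Read the padded logits indices for the forward pass currently in
    # flight; returns None when the env flag is disabled.
    return get_forward_context().logits_indices_padded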
26 changes: 26 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -318,6 +318,12 @@ def __init__(
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}

self.logits_indices = None
if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES:
self.logits_indices = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)

def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
@@ -1364,6 +1370,7 @@ def execute_model(
spec_decode_metadata, num_scheduled_tokens_np,
spec_decode_common_attn_metadata) = (
self._prepare_inputs(scheduler_output))

num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1436,6 +1443,24 @@ def execute_model(
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

logits_indices_padded = None
if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES:
assert self.logits_indices is not None
num_logits = logits_indices.shape[0]
assert num_logits > 0
self.logits_indices[:num_logits].copy_(logits_indices)
# Pad the tail with duplicates of the last index rather than zeros
self.logits_indices[num_logits:].fill_(logits_indices[-1].item())
if (self.use_cuda_graph
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(
num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = self.logits_indices[:num_logits_padded]

# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
@@ -1444,6 +1469,7 @@
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
logits_indices_padded=logits_indices_padded,
):
self.maybe_setup_kv_connector(scheduler_output)
