diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 798c136fc239..12ec5b0fcc66 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -3,6 +3,7 @@ import contextlib import importlib.metadata import os +import platform import random import threading from collections.abc import Callable, Collection @@ -67,6 +68,11 @@ T = TypeVar("T") +# Pin memory in non-WSL case. +# Logic duplicated here for now to avoid circular import. +PIN_MEMORY = "microsoft" not in " ".join(platform.uname()).lower() + + def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: return ( kv_cache_dtype.startswith("fp8") @@ -602,12 +608,12 @@ def create_kv_caches_with_random( def async_tensor_h2d( data: list, dtype: torch.dtype, - target_device: str | torch.device, - pin_memory: bool, + device: str | torch.device, + pin_memory: bool = PIN_MEMORY, ) -> torch.Tensor: """Asynchronously create a tensor and copy it from host to device.""" t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") - return t.to(device=target_device, non_blocking=True) + return t.to(device=device, non_blocking=True) def make_ndarray_with_pad( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 2de61a2b1f28..7c1a784888eb 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1105,7 +1105,8 @@ def build( paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[ prefill_start : num_reqs + 1 ] - paged_kv_indptr_prefill_gpu[0] = 0 + # Assign to slice to avoid cpu sync. + paged_kv_indptr_prefill_gpu[:1] = 0 torch.cumsum( num_blocks_per_req, dim=0, diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 1de6eb408ae2..0ce5baf1a341 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -48,12 +48,16 @@ flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) -def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: - device = offsets.device - counts = offsets[1:] - offsets[:-1] - return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts +def _offsets_to_doc_ids_tensor( + offsets_cpu: torch.Tensor, device: torch.device +) -> torch.Tensor: + # Build on CPU (so `repeat_interleave` doesn't force a GPU->CPU sync to + # learn the data-dependent output length) and upload non-blocking. + counts = offsets_cpu[1:] - offsets_cpu[:-1] + doc_ids = torch.repeat_interleave( + torch.arange(len(counts), dtype=torch.int32), counts ) + return doc_ids.to(device, non_blocking=True) def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): @@ -290,11 +294,13 @@ def unique_static_unsorted( keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] # ── left-pack uniques into a fresh tensor ─────────────────────────── + # Route non-kept entries to a garbage slot at column N so we can do a + # single scatter rather than using torch.nonzero (which would force a + # GPU->CPU sync to enumerate kept positions). 
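The left-pack rewrite below avoids `torch.nonzero` entirely. As an illustrative aside (not part of the patch), here is a minimal CPU-runnable sketch of that trick; `x_flat`, `keep` and `pad_val` mirror the names in the hunk, but the snippet only performs the packing step and assumes the dedup mask `keep` has already been computed.

```python
import torch

def left_pack(x_flat: torch.Tensor, keep: torch.Tensor, pad_val: int) -> torch.Tensor:
    """Pack kept values to the left of each row without calling nonzero()."""
    B, N = x_flat.shape
    # Destination column for each kept value; dropped values are routed to an
    # extra garbage column N (collisions there don't matter, it is sliced off).
    dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1
    dest_pos = torch.where(keep, dest_pos, N)
    packed = torch.full((B, N + 1), pad_val, dtype=x_flat.dtype, device=x_flat.device)
    packed.scatter_(1, dest_pos, x_flat)  # single scatter, no host sync needed
    return packed[:, :N]

x = torch.tensor([[7, 3, 7, 5], [1, 1, 2, 9]])
keep = torch.tensor([[True, True, False, True], [True, False, True, False]])
print(left_pack(x, keep, pad_val=-1))
# tensor([[ 7,  3,  5, -1],
#         [ 1,  2, -1, -1]])
```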
dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go - packed_flat = torch.full_like(x_flat, pad_val) - - rows, src_cols = torch.nonzero(keep, as_tuple=True) - packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + dest_pos = torch.where(keep, dest_pos, N) + packed_extended = torch.full((B, N + 1), pad_val, device=device, dtype=x_flat.dtype) + packed_flat = packed_extended.scatter_(1, dest_pos, x_flat)[:, :N] # ── restore original layout ───────────────────────────────────────── packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) @@ -346,6 +352,9 @@ class FlexAttentionMetadata: num_actual_tokens: int # Number of tokens excluding padding. max_query_len: int query_start_loc: torch.Tensor + # CPU-resident copy of query_start_loc used to derive doc_ids without a + # GPU->CPU sync from repeat_interleave's data-dependent output size. + query_start_loc_cpu: torch.Tensor max_seq_len: int seq_lens: torch.Tensor block_table: torch.Tensor @@ -452,12 +461,7 @@ def final_mask_mod( (is_valid, logical_q_idx, logical_kv_idx) = ( self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) ) - # Apply mask modification only for valid indices - return torch.where( - is_valid, - self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx), - False, - ) + return is_valid & self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx) return final_mask_mod @@ -469,7 +473,9 @@ def get_bidirectional_mask_mod(self) -> _mask_mod_signature: packed query sequences. """ # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + request_lookup = _offsets_to_doc_ids_tensor( + self.query_start_loc_cpu, self.query_start_loc.device + ) def final_mask_mod( b: torch.Tensor, @@ -581,7 +587,9 @@ def get_transformed_score_mod(self) -> _score_mod_signature | None: return None # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + request_lookup = _offsets_to_doc_ids_tensor( + self.query_start_loc_cpu, self.query_start_loc.device + ) user_score_mod = self.score_mod def transformed_score_mod( @@ -726,7 +734,9 @@ def __post_init__(self): assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." 
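The reworked `_offsets_to_doc_ids_tensor` is easy to exercise with concrete numbers. The sketch below is CPU-only; the final non-blocking upload is shown as a comment since it assumes a CUDA device.

```python
import torch

# Toy query_start_loc for three requests with 2, 3 and 1 tokens.
query_start_loc_cpu = torch.tensor([0, 2, 5, 6], dtype=torch.int32)

counts = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [2, 3, 1]
doc_ids = torch.repeat_interleave(
    torch.arange(len(counts), dtype=torch.int32), counts
)
print(doc_ids)  # tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32)

# On a CUDA build the builder would then upload without blocking:
#   doc_ids = doc_ids.to("cuda", non_blocking=True)
```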
# Create a lookup mapping from query indices -> request number - self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) + self.doc_ids = _offsets_to_doc_ids_tensor( + self.query_start_loc_cpu, self.query_start_loc.device + ) self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids) self.num_blocks = self.total_cache_tokens // self.block_size @@ -807,6 +817,7 @@ def build( max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -871,6 +882,7 @@ def build( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=block_table_tensor, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index eec53032288d..716dfcde592f 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -9,6 +9,7 @@ from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -270,16 +271,20 @@ def _build_chunk_metadata_tensors( num_prefills = common.num_prefills num_decode_tokens = common.num_decode_tokens - num_computed_tokens_cpu = ( - common_attn_metadata.compute_num_computed_tokens().cpu() - ) - num_computed_tokens_p_cpu = num_computed_tokens_cpu[ - num_reqs - num_prefills : num_reqs - ] + # Derive prefill context lengths from CPU data only. + # `seq_lens_cpu_upper_bound` is precise for prefill rows in all modes + # (including async spec decode), so this avoids the D2H sync that + # `compute_num_computed_tokens().cpu()` would force. + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + assert seq_lens_cpu is not None query_start_loc_p_cpu = ( common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens ) + prefill_query_lens_cpu = query_start_loc_p_cpu[1:] - query_start_loc_p_cpu[:-1] + num_computed_tokens_p_cpu = ( + seq_lens_cpu[num_reqs - num_prefills : num_reqs] - prefill_query_lens_cpu + ) cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata( chunk_size, @@ -289,20 +294,14 @@ def _build_chunk_metadata_tensors( ) device = common_attn_metadata.query_start_loc.device - cu_chunk_seqlen_p = torch.as_tensor( - cu_chunk_seqlen, - device=device, - dtype=torch.int32, - ) - seq_idx_p = torch.as_tensor( - seq_idx, - device=device, - dtype=torch.int32, + # Build on pinned CPU and upload non-blocking to avoid the synchronous + # H2D copy that `torch.as_tensor(list, device=cuda)` would force. 
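`async_tensor_h2d` (updated at the top of this diff) is what makes the replacement above asynchronous. Condensed for illustration, it amounts to staging in pinned host memory and issuing a non-blocking copy; the snippet assumes a CUDA device is available.

```python
import torch

def async_tensor_h2d(data, dtype, device, pin_memory=True):
    # Stage the Python list in pinned (page-locked) host memory, then issue a
    # non-blocking copy; the host thread does not wait for the transfer.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=device, non_blocking=True)

if torch.cuda.is_available():
    cu_chunk_seqlen = [0, 4, 9, 12]
    gpu_t = async_tensor_h2d(cu_chunk_seqlen, torch.int32, "cuda")
    # Contrast with torch.as_tensor(cu_chunk_seqlen, device="cuda"), which
    # stages in pageable memory and blocks the host until the copy completes.
```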
+ cu_chunk_seqlen_p = async_tensor_h2d( + cu_chunk_seqlen, dtype=torch.int32, device=device ) - last_chunk_indices_p = torch.as_tensor( - last_chunk_indices, - device=device, - dtype=torch.int32, + seq_idx_p = async_tensor_h2d(seq_idx, dtype=torch.int32, device=device) + last_chunk_indices_p = async_tensor_h2d( + last_chunk_indices, dtype=torch.int32, device=device ) return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index ceee8d5499ea..af9c91d11ee2 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -90,6 +90,14 @@ class TreeAttentionMetadata: num_prefills: int = 0 num_decodes: int = 0 + # Precomputed (on CPU in the builder) max_query_len and max_seq_len for + # the prefill-only and decode-only sub-batches. Used by the properties + # below to avoid a GPU->CPU sync via `.max().item()` on every forward. + max_query_len_prefill: int = 0 + max_seq_len_prefill: int = 0 + max_query_len_decode: int = 0 + max_seq_len_decode: int = 0 + tree_attn_bias: torch.Tensor | None = None # Cached Prefill/decode metadata. @@ -107,14 +115,13 @@ def prefill_metadata(self) -> "TreeAttentionMetadata | None": return self._cached_prefill_metadata q_start_loc = self.query_start_loc[self.num_decodes :] - q_seqlens = torch.diff(q_start_loc) kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, - max_query_len=int(q_seqlens.max().item()), + max_query_len=self.max_query_len_prefill, query_start_loc=q_start_loc - q_start_loc[0], - max_seq_len=int(kv_seqlens.max().item()), + max_seq_len=self.max_seq_len_prefill, seq_lens=kv_seqlens, block_table=self.block_table[self.num_decodes :], slot_mapping=self.slot_mapping[self.num_decode_tokens :], @@ -132,14 +139,13 @@ def decode_metadata(self) -> "TreeAttentionMetadata | None": return self._cached_decode_metadata q_start_loc = self.query_start_loc[: self.num_decodes + 1] - q_seqlens = torch.diff(q_start_loc) kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_decode_tokens, - max_query_len=int(q_seqlens.max().item()), + max_query_len=self.max_query_len_decode, query_start_loc=q_start_loc, - max_seq_len=int(kv_seqlens.max().item()), + max_seq_len=self.max_seq_len_decode, seq_lens=kv_seqlens, block_table=self.block_table[: self.num_decodes], slot_mapping=self.slot_mapping[: self.num_decode_tokens], @@ -199,6 +205,42 @@ def build( block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + # Precompute prefill/decode sub-batch max_query_len / max_seq_len on + # CPU so the prefill_metadata / decode_metadata properties don't need + # a GPU->CPU sync via `.max().item()` on every forward. + # Prefer `seq_lens_cpu_upper_bound` over the (deprecated) + # `seq_lens_cpu` property: the upper bound is precise for prefill + # rows and optimistic-but-safe for decode rows (workspace sizing + # from `max()` is fine with an over-estimate), and avoids the + # `seq_lens.to("cpu")` sync the property would fall through to in + # async-spec-decode mode. The draft-attention path (eagle + # speculator) doesn't populate it; fall back to the batch-wide + # `max_seq_len` as a safe upper bound for both sub-batches. 
+ q_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + if num_prefills > 0: + q_seqlens_p = torch.diff(q_start_loc_cpu[num_decodes:]) + max_query_len_prefill = int(q_seqlens_p.max()) + max_seq_len_prefill = ( + int(seq_lens_cpu[num_decodes:].max()) + if seq_lens_cpu is not None + else max_seq_len + ) + else: + max_query_len_prefill = 0 + max_seq_len_prefill = 0 + if num_decodes > 0: + q_seqlens_d = torch.diff(q_start_loc_cpu[: num_decodes + 1]) + max_query_len_decode = int(q_seqlens_d.max()) + max_seq_len_decode = ( + int(seq_lens_cpu[:num_decodes].max()) + if seq_lens_cpu is not None + else max_seq_len + ) + else: + max_query_len_decode = 0 + max_seq_len_decode = 0 + return TreeAttentionMetadata( num_actual_tokens=num_actual_tokens, num_prefill_tokens=num_prefill_tokens, @@ -211,6 +253,10 @@ def build( seq_lens=kv_seqlens, block_table=block_table, slot_mapping=slot_mapping, + max_query_len_prefill=max_query_len_prefill, + max_seq_len_prefill=max_seq_len_prefill, + max_query_len_decode=max_query_len_decode, + max_seq_len_decode=max_seq_len_decode, tree_attn_bias=self.tree_attn_bias, ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index f254d95a414c..0aeb33ae6fd9 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,7 +18,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import next_power_of_2 -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import async_tensor_h2d, is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -117,10 +117,9 @@ def compute_mm_prefix_range_tensor( for r in range_lists: padded_r = list(r) + [(0, 0)] * (max_ranges - len(r)) padded.append(padded_r) - # Create tensor with efficient H2D transfer - return torch.tensor(padded, dtype=torch.int32, device=device).view( - num_seqs, max_ranges, 2 - ) + # Build on pinned CPU memory so the H2D transfer is non-blocking. + padded = async_tensor_h2d(padded, dtype=torch.int32, device=device) + return padded.view(num_seqs, max_ranges, 2) class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py index af2d0fb0830f..53684b4360f7 100644 --- a/vllm/v1/attention/backends/turboquant_attn.py +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -187,6 +187,10 @@ class TurboQuantMetadata(AttentionMetadata): is_prefill: bool = False num_decodes: int = 0 # number of decode requests (first in batch) num_decode_tokens: int = 0 # tokens from decode requests + # CPU-resident copies used by the prefill path for per-request iteration + # without per-step D2H syncs. + query_start_loc_cpu: torch.Tensor | None = None + seq_lens_cpu: torch.Tensor | None = None class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]): @@ -230,6 +234,8 @@ def build(self, common_prefix_len, common_attn_metadata, fast_build=False): is_prefill=(cam.max_query_len > 1), num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, + query_start_loc_cpu=cam.query_start_loc_cpu, + seq_lens_cpu=cam.seq_lens_cpu_upper_bound, ) @@ -474,11 +480,21 @@ def forward( # first-chunk prefills. Using full-batch max_seq_len breaks # this because decode requests inflate max_seq_len. 
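Stepping back to the `compute_mm_prefix_range_tensor` change in the triton_attn hunk above: padding the per-sequence range lists to a rectangle is what allows a single H2D copy (done via `async_tensor_h2d` in the patch). A CPU-only sketch of that shaping step, with made-up ranges:

```python
import torch

# Per-sequence lists of (start, end) ranges with different lengths.
range_lists = [[(0, 3)], [(0, 2), (5, 9)], []]
max_ranges = max(len(r) for r in range_lists)
padded = [list(r) + [(0, 0)] * (max_ranges - len(r)) for r in range_lists]

# One rectangular host tensor means one upload; with pinned staging the
# copy can also be issued non-blocking, as the patch does.
t = torch.tensor(padded, dtype=torch.int32)
print(t.shape)  # torch.Size([3, 2, 2]) == [num_seqs, max_ranges, 2]
```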
prefill_seq_lens = attn_metadata.seq_lens[num_decodes:] - # Use CPU-side max to avoid GPU→CPU sync from .item() - prefill_max_seq = max(attn_metadata.seq_lens[num_decodes:].tolist()) + # Use the CPU-resident `seq_lens` upper-bound from the metadata + # (populated in the builder) to compute the prefill sub-batch + # max without a GPU→CPU sync. + if attn_metadata.seq_lens_cpu is not None: + prefill_max_seq = int(attn_metadata.seq_lens_cpu[num_decodes:].max()) + else: + prefill_max_seq = attn_metadata.max_seq_len prefill_qsl = ( attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens ) + prefill_qsl_cpu = None + if attn_metadata.query_start_loc_cpu is not None: + prefill_qsl_cpu = ( + attn_metadata.query_start_loc_cpu[num_decodes:] - num_decode_tokens + ) prefill_meta = TurboQuantMetadata( seq_lens=prefill_seq_lens, slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N], @@ -488,6 +504,10 @@ def forward( max_query_len=attn_metadata.max_query_len, max_seq_len=prefill_max_seq, is_prefill=True, + query_start_loc_cpu=prefill_qsl_cpu, + seq_lens_cpu=attn_metadata.seq_lens_cpu[num_decodes:] + if attn_metadata.seq_lens_cpu is not None + else None, ) k = key[:N].view(N, self.num_kv_heads, self.head_size) v = value[:N].view(N, self.num_kv_heads, self.head_size) @@ -578,10 +598,16 @@ def _prefill_attention( output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype) - # Convert to Python lists once (single CPU-GPU sync) instead of - # per-request .item() calls that each force a sync. - qsl = query_start_loc.tolist() - seq_lens_list = attn_metadata.seq_lens.tolist() + # Prefer the CPU-resident copies from the metadata if populated — + # otherwise `.tolist()` on GPU tensors forces a synchronizing copy. + if attn_metadata.query_start_loc_cpu is not None: + qsl = attn_metadata.query_start_loc_cpu.tolist() + else: + qsl = query_start_loc.tolist() + if attn_metadata.seq_lens_cpu is not None: + seq_lens_list = attn_metadata.seq_lens_cpu.tolist() + else: + seq_lens_list = attn_metadata.seq_lens.tolist() # Pre-allocate cu_seqlens for single-request flash_attn calls # to avoid per-request host→device tensor creation. @@ -612,7 +638,8 @@ def _prefill_attention( if q_len == seq_len: # First-chunk prefill: all K/V are in the current batch. if _HAS_FLASH_ATTN: - self._cu_2[1] = q_len + # Assign to slice to avoid gpu/cpu sync. + self._cu_2[1:2] = q_len cu = self._cu_2 out = self._flash_attn_varlen( q=q_seq, @@ -791,8 +818,9 @@ def _continuation_prefill( if not hasattr(self, "_cu_2_q"): self._cu_2_q = torch.zeros(2, device=device, dtype=torch.int32) self._cu_2_k = torch.zeros(2, device=device, dtype=torch.int32) - self._cu_2_q[1] = q_len - self._cu_2_k[1] = seq_len + # Assigning to slice uses fill_ which avoids cpu/gpu sync. + self._cu_2_q[1:2] = q_len + self._cu_2_k[1:2] = seq_len cu_seqlens_q = self._cu_2_q cu_seqlens_k = self._cu_2_k return self._flash_attn_varlen( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 54ebd088b95e..43cbcfec1844 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -332,8 +332,10 @@ def make_local_attention_virtual_batches( # regression when using numpy arrays (batch and block indices) to index into # torch tensor (block_table). As a workaround, convert numpy arrays to torch # tensor first, which recovers perf. 
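Several hunks above replace scalar item writes on CUDA tensors with length-1 slice writes (`cu[1] = x` becomes `cu[1:2] = x`). A minimal illustration of the two forms, assuming a CUDA device; the sync-avoidance claim is the patch's, the snippet just shows the pattern.

```python
import torch

if torch.cuda.is_available():
    cu = torch.zeros(2, device="cuda", dtype=torch.int32)
    q_len = 128

    # Scalar item assignment copies a host scalar into one element and, per
    # the comments in this patch, can stall the host thread.
    cu[1] = q_len

    # Writing through a length-1 slice dispatches to fill_ on the view and
    # keeps the launch asynchronous.
    cu[1:2] = q_len
```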
- batch_indices_torch = torch.from_numpy(batch_indices) - block_indices_torch = torch.from_numpy(block_indices) + # Upload the index tensors to the block_table's device up-front so that the + # fancy indexing below doesn't implicitly force a synchronous H2D copy. + batch_indices_torch = torch.from_numpy(batch_indices).to(device, non_blocking=True) + block_indices_torch = torch.from_numpy(block_indices).to(device, non_blocking=True) # Save as a lambda so we can return this for update_block_table make_block_table = lambda block_table: block_table[ @@ -391,7 +393,16 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Figure out how many tokens are in each request # num_decode_tokens: [1, 2, 1] - num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + # Avoid `torch.bincount` here — on CUDA it forces a sync to determine + # the output size (even with `minlength`, the kernel must confirm no + # value exceeds the bound). `scatter_add_` into a preallocated buffer + # is equivalent and stays async. + num_decode_tokens = torch.zeros( + num_reqs, dtype=request_ids.dtype, device=request_ids.device + ) + num_decode_tokens.scatter_add_( + 0, request_ids.to(num_decode_tokens.dtype), torch.ones_like(request_ids) + ) # Calculate new query_start_loc with tokens in generation_indices # decode_query_start_loc: [0, 1, 3, 4] @@ -399,7 +410,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype ) - decode_query_start_loc[0] = 0 + decode_query_start_loc[:1].fill_(0) # Avoid sync from scalar assignment. decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) decode_max_query_len = int(num_decode_tokens.max().item()) total_num_decode_tokens = int(num_decode_tokens.sum().item()) diff --git a/vllm/v1/worker/gpu/buffer_utils.py b/vllm/v1/worker/gpu/buffer_utils.py index a653c262556c..5963790a7792 100644 --- a/vllm/v1/worker/gpu/buffer_utils.py +++ b/vllm/v1/worker/gpu/buffer_utils.py @@ -167,7 +167,7 @@ def apply_write(self) -> None: # Special handling for write_contents write_contents = async_tensor_h2d( - self._staged_write_contents, self.dtype, self.device, pin_memory=True + self._staged_write_contents, self.dtype, self.device ) # Write diffs to the GPU buffer diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py index 04adf9369233..057517479b51 100644 --- a/vllm/v1/worker/gpu/sample/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -58,8 +58,7 @@ def apply_staged_writes(self) -> None: idx_mapping = async_tensor_h2d( self._new_penalties_reqs, dtype=torch.int32, - target_device=self.device, - pin_memory=True, + device=self.device, ) prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
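Finally, the `bincount` to `scatter_add_` swap in utils.py is easy to sanity-check on CPU, where both paths run; the async benefit only shows up on CUDA. Note that `scatter_add_` expects int64 indices.

```python
import torch

request_ids = torch.tensor([0, 1, 1, 2], dtype=torch.int64)
num_reqs = 3

counts = torch.zeros(num_reqs, dtype=request_ids.dtype)
counts.scatter_add_(0, request_ids, torch.ones_like(request_ids))

print(counts)                                           # tensor([1, 2, 1])
print(torch.bincount(request_ids, minlength=num_reqs))  # tensor([1, 2, 1])
```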