-
Notifications
You must be signed in to change notification settings - Fork 895
feat(gdn): separate input and output pool indices #2905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1107,7 +1107,8 @@ def gdn_decode_bf16state_mtp_kernel( | |||||||||
| v: cute.Tensor, # [B, T, HV, V] | ||||||||||
| b: cute.Tensor, # [B, T, HV] | ||||||||||
| o: cute.Tensor, # [B, T, HV, V] - output | ||||||||||
| h0_indices: cute.Tensor, # [B] - initial state indices | ||||||||||
| h0_indices: cute.Tensor, # [B] - initial state indices (read) | ||||||||||
| h0_out_indices: cute.Tensor, # [B] - output state indices (write) | ||||||||||
| softplus_beta: cutlass.Constexpr[float], | ||||||||||
| softplus_threshold: cutlass.Constexpr[float], | ||||||||||
| scale: cutlass.Constexpr[float], | ||||||||||
|
|
@@ -1320,6 +1321,8 @@ def gdn_decode_bf16state_mtp_kernel( | |||||||||
|
|
||||||||||
| # Each group handles tile_v/num_groups V rows, 8 at a time (ILP=8) | ||||||||||
| flat_state_idx = cache_idx * HV + i_hv | ||||||||||
| write_cache_idx = h0_out_indices[i_n] | ||||||||||
| flat_write_idx = write_cache_idx * HV + i_hv | ||||||||||
| rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups | ||||||||||
| eighth_rows: cutlass.Constexpr[int] = rows_per_group // MTP_ILP_ROWS | ||||||||||
|
|
||||||||||
|
|
@@ -1975,14 +1978,38 @@ def gdn_decode_bf16state_mtp_kernel( | |||||||||
| r_hb5[i] = cutlass.BFloat16(r_h[5, i]) | ||||||||||
| r_hb6[i] = cutlass.BFloat16(r_h[6, i]) | ||||||||||
| r_hb7[i] = cutlass.BFloat16(r_h[7, i]) | ||||||||||
| cute.autovec_copy(r_hb0, ht0) | ||||||||||
| cute.autovec_copy(r_hb1, ht1) | ||||||||||
| cute.autovec_copy(r_hb2, ht2) | ||||||||||
| cute.autovec_copy(r_hb3, ht3) | ||||||||||
| cute.autovec_copy(r_hb4, ht4) | ||||||||||
| cute.autovec_copy(r_hb5, ht5) | ||||||||||
| cute.autovec_copy(r_hb6, ht6) | ||||||||||
| cute.autovec_copy(r_hb7, ht7) | ||||||||||
| wt0 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v0, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt1 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v1, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt2 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v2, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt3 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v3, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt4 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v4, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt5 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v5, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt6 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v6, lane_in_group) | ||||||||||
| ) | ||||||||||
| wt7 = cute.local_tile( | ||||||||||
| h0_source, (1, 1, vec_size), (flat_write_idx, v7, lane_in_group) | ||||||||||
| ) | ||||||||||
| cute.autovec_copy(r_hb0, wt0) | ||||||||||
| cute.autovec_copy(r_hb1, wt1) | ||||||||||
| cute.autovec_copy(r_hb2, wt2) | ||||||||||
| cute.autovec_copy(r_hb3, wt3) | ||||||||||
| cute.autovec_copy(r_hb4, wt4) | ||||||||||
| cute.autovec_copy(r_hb5, wt5) | ||||||||||
| cute.autovec_copy(r_hb6, wt6) | ||||||||||
| cute.autovec_copy(r_hb7, wt7) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| # ============================================================================== | ||||||||||
|
|
@@ -2003,6 +2030,7 @@ def run_gdn_decode_bf16state_mtp( | |||||||||
| b: cute.Tensor, | ||||||||||
| o: cute.Tensor, | ||||||||||
| h0_indices: cute.Tensor, | ||||||||||
| h0_out_indices: cute.Tensor, | ||||||||||
| softplus_beta: cutlass.Constexpr[float], | ||||||||||
| softplus_threshold: cutlass.Constexpr[float], | ||||||||||
| scale: cutlass.Constexpr[float], | ||||||||||
|
|
@@ -2056,6 +2084,7 @@ def run_gdn_decode_bf16state_mtp( | |||||||||
| b, | ||||||||||
| o, | ||||||||||
| h0_indices, | ||||||||||
| h0_out_indices, | ||||||||||
| softplus_beta, | ||||||||||
| softplus_threshold, | ||||||||||
| scale, | ||||||||||
|
|
@@ -2467,6 +2496,7 @@ def gated_delta_rule_mtp( | |||||||||
| b: Optional[torch.Tensor] = None, | ||||||||||
| initial_state_source: Optional[torch.Tensor] = None, | ||||||||||
| initial_state_indices: Optional[torch.Tensor] = None, | ||||||||||
| output_state_indices: Optional[torch.Tensor] = None, | ||||||||||
| intermediate_states_buffer: Optional[torch.Tensor] = None, | ||||||||||
| disable_state_update: bool = False, | ||||||||||
| use_qk_l2norm_in_kernel: bool = True, | ||||||||||
|
|
@@ -2487,7 +2517,9 @@ def gated_delta_rule_mtp( | |||||||||
| v: [B, T, HV, V] bf16 | ||||||||||
| b: [B, T, HV] bf16 | ||||||||||
| initial_state_source: [pool_size, HV, V, K] bf16 | ||||||||||
| initial_state_indices: [B] int32 - indices into state pool | ||||||||||
| initial_state_indices: [B] int32 - indices into state pool (read) | ||||||||||
| output_state_indices: Optional [B] int32 - indices for writing updated state. | ||||||||||
| Defaults to initial_state_indices when None. | ||||||||||
| intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16 | ||||||||||
| disable_state_update: bool - if True, don't update initial state | ||||||||||
| scale: Optional, default 1/sqrt(K) | ||||||||||
|
|
@@ -2514,6 +2546,12 @@ def gated_delta_rule_mtp( | |||||||||
| if initial_state_indices is None: | ||||||||||
| initial_state_indices = torch.arange(B, dtype=torch.int32, device=q.device) | ||||||||||
|
|
||||||||||
| # Default output indices to read indices | ||||||||||
| if output_state_indices is None: | ||||||||||
| output_state_indices = initial_state_indices | ||||||||||
| elif output_state_indices.dtype != torch.int32: | ||||||||||
| output_state_indices = output_state_indices.to(torch.int32) | ||||||||||
|
Comment on lines
+2549
to
+2553
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Preserve padding/null-buffer semantics when defaulting This regresses the existing BF16 negative-index path: padded reads still come in as π Minimal fix- if output_state_indices is None:
- output_state_indices = initial_state_indices
- elif output_state_indices.dtype != torch.int32:
- output_state_indices = output_state_indices.to(torch.int32)
+ if output_state_indices is None:
+ # Preserve the existing slot-0 null-buffer behavior for padded rows.
+ output_state_indices = initial_state_indices.clamp_min(0)
+ if output_state_indices.dtype != torch.int32:
+ output_state_indices = output_state_indices.to(torch.int32)π€ Prompt for AI Agents |
||||||||||
|
|
||||||||||
| if output is None: | ||||||||||
| output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) | ||||||||||
|
|
||||||||||
|
|
@@ -2552,6 +2590,7 @@ def gated_delta_rule_mtp( | |||||||||
| dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) | ||||||||||
| o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) | ||||||||||
| h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True) | ||||||||||
| h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix formatting to pass pre-commit checks. The pipeline failure indicates this line needs reformatting per π§ Apply ruff formatting- h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
+ h0_out_idx_ = from_dlpack(
+ output_state_indices, assumed_align=32, enable_tvm_ffi=True
+ )π Committable suggestion
Suggested change
π§° Toolsπͺ GitHub Actions: pre-commit[error] 2590-2593: pre-commit failed: ruff format (hook id: ruff-format) reformatted files. Diff shows formatting change in gated_delta_rule_mtp() for h0_out_idx_ = from_dlpack(output_state_indices, ...). π€ Prompt for AI Agents |
||||||||||
|
|
||||||||||
| stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) | ||||||||||
|
|
||||||||||
|
|
@@ -2590,6 +2629,7 @@ def gated_delta_rule_mtp( | |||||||||
| b_, | ||||||||||
| o_, | ||||||||||
| h0_idx_, | ||||||||||
| h0_out_idx_, | ||||||||||
| softplus_beta, | ||||||||||
| softplus_threshold, | ||||||||||
| scale, | ||||||||||
|
|
@@ -2620,6 +2660,7 @@ def gated_delta_rule_mtp( | |||||||||
| b_, | ||||||||||
| o_, | ||||||||||
| h0_idx_, | ||||||||||
| h0_out_idx_, | ||||||||||
| stream, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -55,7 +55,8 @@ def gdn_decode_kernel_small_batch_pretranspose( | |||||||||||||||||
| v: cute.Tensor, # [B, T, HV, V] | ||||||||||||||||||
| b: cute.Tensor, # [B, T, HV] | ||||||||||||||||||
| o: cute.Tensor, # [B, T, HV, V] - output | ||||||||||||||||||
| h0_indices: cute.Tensor, # [B] - initial state indices | ||||||||||||||||||
| h0_indices: cute.Tensor, # [B] - initial state indices (read) | ||||||||||||||||||
| h0_out_indices: cute.Tensor, # [B] - output state indices (write) | ||||||||||||||||||
| cu_seqlens: cute.Tensor, # [B+1] - cumulative sequence lengths (for varlen) | ||||||||||||||||||
| softplus_beta: cutlass.Constexpr[float], | ||||||||||||||||||
| softplus_threshold: cutlass.Constexpr[float], | ||||||||||||||||||
|
|
@@ -134,16 +135,18 @@ def gdn_decode_kernel_small_batch_pretranspose( | |||||||||||||||||
| # Compute state index: use pool indexing if enabled. | ||||||||||||||||||
| if cutlass.const_expr(use_pool_indexing): | ||||||||||||||||||
| pool_idx = h0_indices[i_n] | ||||||||||||||||||
| out_pool_idx = h0_out_indices[i_n] | ||||||||||||||||||
| else: | ||||||||||||||||||
| pool_idx = 0 | ||||||||||||||||||
| out_pool_idx = 0 | ||||||||||||||||||
|
|
||||||||||||||||||
| if pool_idx >= 0: | ||||||||||||||||||
| # Get current state slice. | ||||||||||||||||||
| # Get current batch | ||||||||||||||||||
| if cutlass.const_expr(use_pool_indexing): | ||||||||||||||||||
| # h0_source layout: [pool_size, HV, V, K] (supports non-contiguous page stride) | ||||||||||||||||||
| gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] # (V, K) | ||||||||||||||||||
| gDst = cute.local_tile( | ||||||||||||||||||
| h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0) | ||||||||||||||||||
| h0_source, (1, 1, TILE_V, TILE_K), (out_pool_idx, i_hv, None, 0) | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| # h0_source layout: [B*HV, V, K] | ||||||||||||||||||
|
|
@@ -307,7 +310,7 @@ def gdn_decode_kernel_small_batch_pretranspose( | |||||||||||||||||
| r_h[i] += r_k[i] * v_new | ||||||||||||||||||
| sum_hq += r_h[i] * r_q[i] | ||||||||||||||||||
|
|
||||||||||||||||||
| # Write h back to state. | ||||||||||||||||||
| # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) | ||||||||||||||||||
| if cutlass.const_expr(use_pool_indexing): | ||||||||||||||||||
| gDst_tile = cute.local_tile( | ||||||||||||||||||
| gDst, | ||||||||||||||||||
|
|
@@ -361,7 +364,8 @@ def gdn_decode_kernel_big_batch_pretranspose( | |||||||||||||||||
| v: cute.Tensor, # [B, T, HV, V] | ||||||||||||||||||
| b: cute.Tensor, # [B, T, HV] | ||||||||||||||||||
| o: cute.Tensor, # [B, T, HV, V] - output | ||||||||||||||||||
| h0_indices: cute.Tensor, # [B] - initial state indices | ||||||||||||||||||
| h0_indices: cute.Tensor, # [B] - initial state indices (read) | ||||||||||||||||||
| h0_out_indices: cute.Tensor, # [B] - output state indices (write) | ||||||||||||||||||
| cu_seqlens: cute.Tensor, # [B+1] - cumulative sequence lengths (for varlen) | ||||||||||||||||||
| softplus_beta: cutlass.Constexpr[float], | ||||||||||||||||||
| softplus_threshold: cutlass.Constexpr[float], | ||||||||||||||||||
|
|
@@ -436,16 +440,18 @@ def gdn_decode_kernel_big_batch_pretranspose( | |||||||||||||||||
| # Compute state index: use pool indexing if enabled. | ||||||||||||||||||
| if cutlass.const_expr(use_pool_indexing): | ||||||||||||||||||
| pool_idx = h0_indices[i_n] | ||||||||||||||||||
| out_pool_idx = h0_out_indices[i_n] | ||||||||||||||||||
| else: | ||||||||||||||||||
| pool_idx = 0 | ||||||||||||||||||
| out_pool_idx = 0 | ||||||||||||||||||
|
|
||||||||||||||||||
| if pool_idx >= 0: | ||||||||||||||||||
| # Get current state slice. | ||||||||||||||||||
| if cutlass.const_expr(use_pool_indexing): | ||||||||||||||||||
| # h0_source layout: [pool_size, HV, V, K] (supports non-contiguous page stride) | ||||||||||||||||||
| gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] # (V, K) | ||||||||||||||||||
| gDst = cute.local_tile( | ||||||||||||||||||
| h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0) | ||||||||||||||||||
| h0_source, (1, 1, TILE_V, TILE_K), (out_pool_idx, i_hv, None, 0) | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| # h0_source layout: [B*HV, V, K] | ||||||||||||||||||
|
|
@@ -657,6 +663,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose( | |||||||||||||||||
| b: cute.Tensor, | ||||||||||||||||||
| o: cute.Tensor, | ||||||||||||||||||
| h0_indices: cute.Tensor, | ||||||||||||||||||
| h0_out_indices: cute.Tensor, | ||||||||||||||||||
| cu_seqlens: cute.Tensor, | ||||||||||||||||||
| softplus_beta: cutlass.Constexpr[float], | ||||||||||||||||||
| softplus_threshold: cutlass.Constexpr[float], | ||||||||||||||||||
|
|
@@ -734,6 +741,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose( | |||||||||||||||||
| b, | ||||||||||||||||||
| o, | ||||||||||||||||||
| h0_indices, | ||||||||||||||||||
| h0_out_indices, | ||||||||||||||||||
| cu_seqlens, | ||||||||||||||||||
| softplus_beta, | ||||||||||||||||||
| softplus_threshold, | ||||||||||||||||||
|
|
@@ -768,6 +776,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose( | |||||||||||||||||
| b: cute.Tensor, | ||||||||||||||||||
| o: cute.Tensor, | ||||||||||||||||||
| h0_indices: cute.Tensor, | ||||||||||||||||||
| h0_out_indices: cute.Tensor, | ||||||||||||||||||
| cu_seqlens: cute.Tensor, | ||||||||||||||||||
| softplus_beta: cutlass.Constexpr[float], | ||||||||||||||||||
| softplus_threshold: cutlass.Constexpr[float], | ||||||||||||||||||
|
|
@@ -840,6 +849,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose( | |||||||||||||||||
| b, | ||||||||||||||||||
| o, | ||||||||||||||||||
| h0_indices, | ||||||||||||||||||
| h0_out_indices, | ||||||||||||||||||
| cu_seqlens, | ||||||||||||||||||
| softplus_beta, | ||||||||||||||||||
| softplus_threshold, | ||||||||||||||||||
|
|
@@ -910,6 +920,7 @@ def run_pretranspose_decode( | |||||||||||||||||
| use_qk_l2norm: bool, | ||||||||||||||||||
| use_pool_indexing: bool = False, | ||||||||||||||||||
| initial_state_indices: Optional[torch.Tensor] = None, | ||||||||||||||||||
| output_state_indices: Optional[torch.Tensor] = None, | ||||||||||||||||||
| ): | ||||||||||||||||||
| """Compile and execute the pretranspose decode kernel. | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -924,6 +935,8 @@ def run_pretranspose_decode( | |||||||||||||||||
| use_pool_indexing: Whether to use pool-based indirect state indexing. | ||||||||||||||||||
| initial_state_indices: Int32 indices into state pool, shape [B]. | ||||||||||||||||||
| Negative values indicate padding (kernel writes zeros). | ||||||||||||||||||
| output_state_indices: Optional int32 indices for write destination, shape [B]. | ||||||||||||||||||
| When None, writes go to the same slot as initial_state_indices. | ||||||||||||||||||
| """ | ||||||||||||||||||
| # Compile kernel with TVM FFI (cached) | ||||||||||||||||||
| if use_pool_indexing: | ||||||||||||||||||
|
|
@@ -959,6 +972,11 @@ def run_pretranspose_decode( | |||||||||||||||||
| h0_indices = initial_state_indices.to(torch.int32) | ||||||||||||||||||
| else: | ||||||||||||||||||
| h0_indices = cache["h0_indices"] | ||||||||||||||||||
| # Resolve output indices: default to same as read indices | ||||||||||||||||||
| if use_pool_indexing and output_state_indices is not None: | ||||||||||||||||||
| h0_out_indices = output_state_indices.to(torch.int32) | ||||||||||||||||||
| else: | ||||||||||||||||||
| h0_out_indices = h0_indices | ||||||||||||||||||
|
Comment on lines
+976
to
+979
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The You can simplify this logic for better readability.
Suggested change
|
||||||||||||||||||
| cu_seqlens = cache["cu_seqlens"] | ||||||||||||||||||
|
|
||||||||||||||||||
| if "compiled" not in cache: | ||||||||||||||||||
|
|
@@ -976,6 +994,7 @@ def run_pretranspose_decode( | |||||||||||||||||
| b_tensor = from_dlpack(b, assumed_align=16) | ||||||||||||||||||
| o_tensor = from_dlpack(output, assumed_align=16) | ||||||||||||||||||
| h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16) | ||||||||||||||||||
| h0_out_indices_tensor = from_dlpack(h0_out_indices, assumed_align=16) | ||||||||||||||||||
| cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Always use 8-CTA architecture (benchmarks show it's better for all batch sizes) | ||||||||||||||||||
|
|
@@ -994,6 +1013,7 @@ def run_pretranspose_decode( | |||||||||||||||||
| b_tensor, | ||||||||||||||||||
| o_tensor, | ||||||||||||||||||
| h0_indices_tensor, | ||||||||||||||||||
| h0_out_indices_tensor, | ||||||||||||||||||
| cu_seqlens_tensor, | ||||||||||||||||||
| softplus_beta=1.0, | ||||||||||||||||||
| softplus_threshold=20.0, | ||||||||||||||||||
|
|
@@ -1018,5 +1038,17 @@ def run_pretranspose_decode( | |||||||||||||||||
| # Run kernel directly with PyTorch tensors (no from_dlpack needed) | ||||||||||||||||||
| stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) | ||||||||||||||||||
| cache["compiled"]( | ||||||||||||||||||
| h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream | ||||||||||||||||||
| h0_source, | ||||||||||||||||||
| A_log, | ||||||||||||||||||
| a, | ||||||||||||||||||
| dt_bias, | ||||||||||||||||||
| q, | ||||||||||||||||||
| k, | ||||||||||||||||||
| v, | ||||||||||||||||||
| b, | ||||||||||||||||||
| output, | ||||||||||||||||||
| h0_indices, | ||||||||||||||||||
| h0_out_indices, | ||||||||||||||||||
| cu_seqlens, | ||||||||||||||||||
| stream, | ||||||||||||||||||
| ) | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reject in-place remaps that alias another batch item's source slot.
output_state_indicesstill writes back into the sameinitial_statebuffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading viainitial_state_indices, the final state becomes CTA-order dependent and no longer matches gatherβcomputeβscatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.π€ Prompt for AI Agents
Validate
output_state_indicesagainst the pool before dispatch.The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or
>= pool_sizewrite index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot0on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.π‘ Suggested guard
if output_state_indices is not None: assert use_pool, ( "output_state_indices can only be used with initial_state (pool mode)" ) assert output_state_indices.shape == (B,), ( f"Expected output_state_indices shape [{B}], " f"got {output_state_indices.shape}" ) assert output_state_indices.dtype in (torch.int32, torch.int64), ( f"output_state_indices must be int32 or int64, " f"got {output_state_indices.dtype}" ) + assert output_state_indices.device == initial_state.device, ( + "output_state_indices must be on the same device as initial_state" + ) + pool_size = int(initial_state.shape[0]) + in_range = (output_state_indices >= 0) & (output_state_indices < pool_size) + assert in_range.all().item(), ( + f"output_state_indices must be in [0, {pool_size})" + )π€ Prompt for AI Agents