cp: Revert packed seq extra checks (#2180) into r0.3.0 (#2196)
Conversation
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
/ok to test 75f3195
📝 Walkthrough

Refactored packed sequence utilities to replace assertion-based conditional logic with direct tensor slicing operations. Simplified control flow by removing assertions that enforced specific argmin positions and padding values. The corresponding unit tests for the no-padding and unpadded scenarios were removed to align with the updated implementation.
🎯 Estimated code review effort: 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `src/megatron/bridge/training/utils/packed_seq_utils.py`, around lines 46-55: the slice bounds use `torch.argmin()` (which returns a 0-d tensor), causing CUDA device-index issues. Update the fallback branches to convert the argmin results to Python ints (use `.item()`) before slicing `cu_seqlens_padded` and `cu_seqlens_unpadded`, so that `cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded).item()]` and `cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded).item()]`, matching the existing `.item()` usage in the other branches for consistent, device-safe indexing.
```diff
 if cu_seqlens_argmin is not None:
-    argmin_idx = cu_seqlens_argmin.item()
-    assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1  # cu_seqlens padding is -1
-    cu_seqlens_padded = cu_seqlens_padded[:argmin_idx]
-elif torch.min(cu_seqlens_padded) == -1:
     cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
 else:
     cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]

 if cu_seqlens_unpadded is not None:
     if cu_seqlens_unpadded_argmin is not None:
-        argmin_idx = cu_seqlens_unpadded_argmin.item()
-        assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1  # cu_seqlens padding is -1
-        cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx]
-    elif torch.min(cu_seqlens_unpadded) == -1:
         cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
     else:
         cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
```
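The trimming pattern under review can be reproduced on CPU with a minimal standalone sketch (the tensor values below are made up; only the slicing pattern mirrors the code above). `torch.argmin` returns a 0-dim tensor; converting it to a Python int with `.item()` makes the slice use basic indexing, which sidesteps the tensor-index device rules raised in the comment:

```python
import torch

# Hypothetical cu_seqlens buffer: cumulative sequence lengths, padded with -1.
# (Illustrative values only; not taken from the repository.)
cu_seqlens_padded = torch.tensor([0, 8, 16, 24, -1, -1])

# torch.argmin returns a 0-dim tensor; .item() converts it to a Python int,
# so the slice below uses basic indexing and stays device-safe on CUDA.
trim = torch.argmin(cu_seqlens_padded).item()
cu_seqlens_padded = cu_seqlens_padded[:trim]

print(cu_seqlens_padded.tolist())  # [0, 8, 16, 24]
```

The first `-1` sits at index 4, so the slice drops the padding entries and keeps only the valid cumulative lengths.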
🧩 Analysis chain
🌐 Web query:
PyTorch tensor slicing with 0-dimensional tensor indices CUDA CPU compatibility
💡 Result:
In PyTorch, a 0‑dim tensor used as an index is still a “tensor index” (advanced indexing), so PyTorch applies the device/dtype rules for index tensors:
- Index tensors must be on CPU or on the same device as the tensor being indexed; otherwise you’ll get errors like: “indices should be either on cpu or on the same device as the indexed tensor”. [1]
- On CUDA, some invalid tensor-indexing cases (e.g., out-of-bounds when the index is a tensor) may surface as a device-side assert rather than a clean IndexError, which can be harder to catch/debug. [2]
Practical patterns for CPU/CUDA compatibility
1) If you want "scalar slicing", convert the 0-dim tensor to a Python int:

```python
i = int(i0.item())  # i0 can be a CPU or CUDA 0-dim tensor
y = x[i]            # basic indexing; avoids tensor-index device rules
```

This is the most portable fix. (Note: `.item()` on a CUDA tensor synchronizes.)
2) If you want to keep it as a tensor index, keep devices consistent:

```python
idx = idx.to(x.device)  # and ensure integer dtype (usually torch.long)
y = x[idx]
```

3) Debugging CUDA "device-side assert" from indexing
If you hit a device-side assert during indexing, reproduce with CUDA_LAUNCH_BLOCKING=1 to get a more accurate stack trace; this class of issue is commonly triggered by invalid CUDA indexing with tensor indices. [2]
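The two patterns above can be checked side by side on CPU (a minimal sketch; `x` and `i0` are made-up tensors, not identifiers from the repository):

```python
import torch

x = torch.arange(10)
i0 = torch.tensor(3)  # 0-dim index tensor

# Pattern 1: convert to a Python int -> basic indexing, portable across devices.
y_int = x[int(i0.item())]

# Pattern 2: keep a tensor index, but match the device and dtype
# of the tensor being indexed.
idx = i0.to(device=x.device, dtype=torch.long)
y_tensor = x[idx]

assert y_int.item() == y_tensor.item() == 3
```

On CPU both succeed either way; the device/dtype handling in pattern 2 is what matters once `x` lives on CUDA.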
Sources:
[1] PyTorch Forums: “Indices should be either on cpu or on the same device as the indexed tensor” (discussion of the device rule for index tensors). [1]
[2] PyTorch GitHub issue discussing CUDA tensor indexing leading to device-side asserts instead of IndexError. [2]
🏁 Script executed:

```shell
# Examine the actual file to understand the context
cat -n src/megatron/bridge/training/utils/packed_seq_utils.py | head -70
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 3436

🏁 Script executed:

```shell
# Verify the full context around the processing of these cumulative seqlens
cat -n src/megatron/bridge/training/utils/packed_seq_utils.py | sed -n '36,70p'
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 1818
Convert torch.argmin() results to Python int for CUDA compatibility and consistent slicing behavior.
Lines 49 and 55 use torch.argmin() directly as a slice bound, which returns a 0-dimensional tensor. On CUDA, this can trigger device-side assertion errors because PyTorch requires index tensors to match the device of the indexed tensor. Lines 47 and 53 already use .item() correctly; apply the same pattern to the fallback paths.
Proposed fix:

```diff
 if cu_seqlens_argmin is not None:
     cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
 else:
-    cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]
+    cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded).item()]
 if cu_seqlens_unpadded is not None:
     if cu_seqlens_unpadded_argmin is not None:
         cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
     else:
-        cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
+        cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded).item()]
```
beep boop [🤖]: Hi @cuichenx 👋,
Summary by CodeRabbit

- **Refactor**: Replaced assertion-based checks in the packed sequence utilities with direct tensor slicing.
- **Tests**: Removed unit tests for the no-padding and unpadded scenarios to match the updated implementation.