12 changes: 4 additions & 8 deletions src/megatron/bridge/training/utils/packed_seq_utils.py
@@ -44,18 +44,14 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams:
     cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin")
 
     if cu_seqlens_argmin is not None:
-        argmin_idx = cu_seqlens_argmin.item()
-        assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1  # cu_seqlens padding is -1
-        cu_seqlens_padded = cu_seqlens_padded[:argmin_idx]
-    elif torch.min(cu_seqlens_padded) == -1:
+        cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
+    else:
         cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]
 
     if cu_seqlens_unpadded is not None:
         if cu_seqlens_unpadded_argmin is not None:
-            argmin_idx = cu_seqlens_unpadded_argmin.item()
-            assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1  # cu_seqlens padding is -1
-            cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx]
-        elif torch.min(cu_seqlens_unpadded) == -1:
+            cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
+        else:
             cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
 
     max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None
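
Both branches of the hunk above slice the cu_seqlens tensors at the index of the first -1 padding entry, using the precomputed argmin carried in the batch when it is available and torch.argmin as the fallback. Below is a minimal, self-contained sketch of that slicing step, assuming cu_seqlens is padded with -1; truncate_cu_seqlens is a hypothetical helper for illustration, not a function in packed_seq_utils.py.

import torch

def truncate_cu_seqlens(cu_seqlens, cu_seqlens_argmin=None):
    """Drop the trailing -1 padding from a cumulative-sequence-length tensor (hypothetical helper)."""
    if cu_seqlens_argmin is not None:
        # Use the precomputed index of the first padding entry when the batch provides it.
        return cu_seqlens[: cu_seqlens_argmin.item()]
    # Fallback: the first -1 is the minimum of the tensor, so argmin locates the padding boundary.
    return cu_seqlens[: torch.argmin(cu_seqlens)]

cu_seqlens = torch.tensor([0, 7, 14, -1, -1], dtype=torch.int32)
print(truncate_cu_seqlens(cu_seqlens, torch.tensor(3)))  # tensor([ 0,  7, 14], dtype=torch.int32)
print(truncate_cu_seqlens(cu_seqlens))                   # tensor([ 0,  7, 14], dtype=torch.int32)

Passing the precomputed index through the batch skips a scan over the tensor; the fallback derives the same boundary from the tensor itself.
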
48 changes: 0 additions & 48 deletions tests/unit_tests/training/test_gpt_step.py
@@ -103,23 +103,6 @@ def test_packed_seq_params_with_cu_seqlens_argmin(self):
         assert torch.equal(result.max_seqlen_q, expected_max_seqlen)
         assert torch.equal(result.max_seqlen_kv, expected_max_seqlen)
 
-    def test_packed_seq_params_no_padding(self):
-        """Test functionality when cu_seqlens has no padding (-1 values)."""
-        batch = {
-            "cu_seqlens": torch.tensor([[0, 7, 14]], dtype=torch.int32),
-            "max_seqlen": torch.tensor([[10]], dtype=torch.int32),
-        }
-
-        result = get_packed_seq_params(batch)
-
-        # Verify the result is a PackedSeqParams object
-        assert isinstance(result, PackedSeqParams)
-
-        # When there's no -1 padding, the tensor is returned unchanged
-        expected_cu_seqlens = torch.tensor([0, 7, 14], dtype=torch.int32)
-        assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
-        assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)
-
     def test_packed_seq_params_with_cu_seqlens_argmin_zero(self):
         """Test edge case when cu_seqlens_argmin is 0."""
         batch = {
@@ -246,21 +229,6 @@ def test_packed_seq_params_without_unpadded_fallback(self):
         assert result.cu_seqlens_q_padded is None
         assert result.cu_seqlens_kv_padded is None
 
-    def test_packed_seq_params_no_padding_in_cu_seqlens(self):
-        """Test when cu_seqlens has no -1 padding markers."""
-        batch = {
-            "cu_seqlens": torch.tensor([[0, 5, 10]], dtype=torch.int32),  # No -1 padding
-            "max_seqlen": torch.tensor([[7]], dtype=torch.int32),
-        }
-
-        result = get_packed_seq_params(batch)
-
-        # When no -1 present and min != -1, the tensor should remain as-is
-        expected = torch.tensor([0, 5, 10], dtype=torch.int32)
-        assert torch.equal(result.cu_seqlens_q, expected)
-        # Padded fields are None when cu_seqlens_unpadded is not provided
-        assert result.cu_seqlens_q_padded is None
-
     def test_packed_seq_params_qkv_format_is_thd(self):
         """Test that qkv_format is always set to 'thd'."""
         batch = {
@@ -271,22 +239,6 @@ def test_packed_seq_params_qkv_format_is_thd(self):
 
         assert result.qkv_format == "thd"
 
-    def test_packed_seq_params_cu_seqlens_unpadded_no_padding(self):
-        """Test cu_seqlens_unpadded with no padding markers."""
-        batch = {
-            "cu_seqlens": torch.tensor([[0, 6, 12]], dtype=torch.int32),
-            "cu_seqlens_unpadded": torch.tensor([[0, 5, 10]], dtype=torch.int32),  # No -1
-        }
-
-        result = get_packed_seq_params(batch)
-
-        # Unpadded should be used as-is since no -1 and min != -1
-        expected_unpadded = torch.tensor([0, 5, 10], dtype=torch.int32)
-        expected_padded = torch.tensor([0, 6, 12], dtype=torch.int32)
-
-        assert torch.equal(result.cu_seqlens_q, expected_unpadded)
-        assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
-
 
 class TestCreateLossFunction:
     """Tests for the _create_loss_function helper function."""
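
With the no-padding tests above removed, the remaining coverage goes through the argmin-driven path. Below is a hedged, test-style sketch of that path; the import path is assumed from the src/ layout shown in this diff, and the batch shapes and field names mirror the surviving tests rather than a confirmed API.

import torch

# Assumed import path, derived from src/megatron/bridge/training/utils/packed_seq_utils.py.
from megatron.bridge.training.utils.packed_seq_utils import get_packed_seq_params

batch = {
    # -1 entries are padding; cu_seqlens_argmin points at the first of them.
    "cu_seqlens": torch.tensor([[0, 7, 14, -1]], dtype=torch.int32),
    "cu_seqlens_argmin": torch.tensor([3]),
    "max_seqlen": torch.tensor([[10]], dtype=torch.int32),
}

params = get_packed_seq_params(batch)

# Expected under the assumptions above: padding stripped and the thd layout selected.
assert torch.equal(params.cu_seqlens_q, torch.tensor([0, 7, 14], dtype=torch.int32))
assert params.qkv_format == "thd"
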