diff --git a/src/megatron/bridge/training/utils/packed_seq_utils.py b/src/megatron/bridge/training/utils/packed_seq_utils.py index e631be426a..98dbd6d5ac 100644 --- a/src/megatron/bridge/training/utils/packed_seq_utils.py +++ b/src/megatron/bridge/training/utils/packed_seq_utils.py @@ -44,18 +44,14 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams: cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin") if cu_seqlens_argmin is not None: - argmin_idx = cu_seqlens_argmin.item() - assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1 # cu_seqlens padding is -1 - cu_seqlens_padded = cu_seqlens_padded[:argmin_idx] - elif torch.min(cu_seqlens_padded) == -1: + cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()] + else: cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)] if cu_seqlens_unpadded is not None: if cu_seqlens_unpadded_argmin is not None: - argmin_idx = cu_seqlens_unpadded_argmin.item() - assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1 # cu_seqlens padding is -1 - cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx] - elif torch.min(cu_seqlens_unpadded) == -1: + cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()] + else: cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)] max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None diff --git a/tests/unit_tests/training/test_gpt_step.py b/tests/unit_tests/training/test_gpt_step.py index d677c1334f..7e308f2f85 100644 --- a/tests/unit_tests/training/test_gpt_step.py +++ b/tests/unit_tests/training/test_gpt_step.py @@ -103,23 +103,6 @@ def test_packed_seq_params_with_cu_seqlens_argmin(self): assert torch.equal(result.max_seqlen_q, expected_max_seqlen) assert torch.equal(result.max_seqlen_kv, expected_max_seqlen) - def test_packed_seq_params_no_padding(self): - """Test functionality when cu_seqlens has no padding (-1 values).""" - batch = { - "cu_seqlens": torch.tensor([[0, 7, 14]], dtype=torch.int32), - "max_seqlen": torch.tensor([[10]], dtype=torch.int32), - } - - result = get_packed_seq_params(batch) - - # Verify the result is a PackedSeqParams object - assert isinstance(result, PackedSeqParams) - - # When there's no -1 padding, the tensor is returned unchanged - expected_cu_seqlens = torch.tensor([0, 7, 14], dtype=torch.int32) - assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens) - assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens) - def test_packed_seq_params_with_cu_seqlens_argmin_zero(self): """Test edge case when cu_seqlens_argmin is 0.""" batch = { @@ -246,21 +229,6 @@ def test_packed_seq_params_without_unpadded_fallback(self): assert result.cu_seqlens_q_padded is None assert result.cu_seqlens_kv_padded is None - def test_packed_seq_params_no_padding_in_cu_seqlens(self): - """Test when cu_seqlens has no -1 padding markers.""" - batch = { - "cu_seqlens": torch.tensor([[0, 5, 10]], dtype=torch.int32), # No -1 padding - "max_seqlen": torch.tensor([[7]], dtype=torch.int32), - } - - result = get_packed_seq_params(batch) - - # When no -1 present and min != -1, the tensor should remain as-is - expected = torch.tensor([0, 5, 10], dtype=torch.int32) - assert torch.equal(result.cu_seqlens_q, expected) - # Padded fields are None when cu_seqlens_unpadded is not provided - assert result.cu_seqlens_q_padded is None - def test_packed_seq_params_qkv_format_is_thd(self): """Test that qkv_format is always set to 'thd'.""" batch = { @@ -271,22 +239,6 @@ def test_packed_seq_params_qkv_format_is_thd(self): assert result.qkv_format == "thd" - def test_packed_seq_params_cu_seqlens_unpadded_no_padding(self): - """Test cu_seqlens_unpadded with no padding markers.""" - batch = { - "cu_seqlens": torch.tensor([[0, 6, 12]], dtype=torch.int32), - "cu_seqlens_unpadded": torch.tensor([[0, 5, 10]], dtype=torch.int32), # No -1 - } - - result = get_packed_seq_params(batch) - - # Unpadded should be used as-is since no -1 and min != -1 - expected_unpadded = torch.tensor([0, 5, 10], dtype=torch.int32) - expected_padded = torch.tensor([0, 6, 12], dtype=torch.int32) - - assert torch.equal(result.cu_seqlens_q, expected_unpadded) - assert torch.equal(result.cu_seqlens_q_padded, expected_padded) - class TestCreateLossFunction: """Tests for the _create_loss_function helper function."""