diff --git a/tests/test_utils.py b/tests/test_utils.py index b95b0611a09..8570ea4bf57 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -28,7 +28,6 @@ RepeatSampler, entropy_from_logits, flush_left, - flush_right, forward_masked_logits, generate_model_card, get_peft_config, @@ -383,50 +382,6 @@ def test_no_tensors(self): assert torch.equal(new_mask, expected_mask) -class TestFlushRight(TrlTestCase): - def test_basic_case(self): - mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]) - tensor1 = torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]]) - tensor2 = torch.tensor([[7, 8, 9, 0, 0], [0, 0, 10, 11, 0]]) - new_mask, new_tensor1, new_tensor2 = flush_right(mask, tensor1, tensor2) - - expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) - expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]]) - expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]]) - - assert torch.equal(new_mask, expected_mask) - assert torch.equal(new_tensor1, expected_tensor1) - assert torch.equal(new_tensor2, expected_tensor2) - - def test_single_row(self): - mask = torch.tensor([[1, 1, 0, 0]]) - tensor1 = torch.tensor([[2, 3, 0, 0]]) - new_mask, new_tensor1 = flush_right(mask, tensor1) - - expected_mask = torch.tensor([[1, 1]]) - expected_tensor1 = torch.tensor([[2, 3]]) - - assert torch.equal(new_mask, expected_mask) - assert torch.equal(new_tensor1, expected_tensor1) - - def test_no_shift_needed(self): - mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]]) - tensor1 = torch.tensor([[0, 0, 5, 6], [0, 0, 0, 7]]) - new_mask, new_tensor1 = flush_right(mask, tensor1) - - expected_mask = torch.tensor([[1, 1], [0, 1]]) - expected_tensor1 = torch.tensor([[5, 6], [0, 7]]) - - assert torch.equal(new_mask, expected_mask) - assert torch.equal(new_tensor1, expected_tensor1) - - def test_no_tensors(self): - mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]) - new_mask = flush_right(mask) - expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) - assert torch.equal(new_mask, expected_mask) - - class TestRepeatRandomSampler(TrlTestCase): def test_sampler(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index fe5bf86ca03..1db504acebe 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -492,36 +492,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor | tup return flushed_mask, *flushed_tensors -def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]: - """ - Shift non-zero elements in the mask and corresponding tensors to the right. See `flush_left` for details. - """ - _, M = mask.shape - - # Create copy of mask and tensors - mask_copy = mask.clone() - tensors = [t.clone() for t in tensors] - - # Shift non-zero values to the right - flipped_mask = torch.fliplr(mask_copy) - first_non_zero = flipped_mask.argmax(dim=1) - pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) - idx_roll = (pos - first_non_zero.unsqueeze(1)) % M - mask_roll = mask_copy.gather(1, idx_roll) - rolled_tensors = [t.gather(1, idx_roll) for t in tensors] - - # Truncate leading columns that are all zeros in mask_roll - col_sums = mask_roll.sum(dim=0) - non_empty_cols = col_sums != 0 - first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M - flushed_mask = mask_roll[:, first_non_empty_col:] - flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors] - - if not flushed_tensors: - return flushed_mask - return flushed_mask, *flushed_tensors - - def selective_log_softmax(logits, index) -> torch.Tensor: """ A memory-efficient implementation of the common `log_softmax -> gather` operation.