diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 269c45ca78..c49470025d 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -235,7 +235,9 @@ def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler: return RandomSampler(dataset, num_samples=num_samples) else: - return LengthAwareSampler(dataset, num_samples=num_samples) + return LengthAwareSampler( + dataset, num_samples=num_samples, batch_size=batch_size + ) def data_collator_with_truncation( @@ -269,9 +271,11 @@ def __init__( self, data_source: Dataset, num_samples: Optional[int] = None, + batch_size: int = 1, ) -> None: self.data_source = data_source self._num_samples = num_samples or len(data_source) + self.batch_size = batch_size if "input_ids" in data_source.column_names: feature_name = "input_ids" @@ -284,6 +288,40 @@ def __init__( lengths = [len(sample) for sample in data_source[feature_name]] self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist() + self._calculate_and_log_batch_stats(lengths) + + def _calculate_and_log_batch_stats(self, lengths: list[int]): + if self.batch_size == 1: + return + + logger.debug( + "LengthAwareSampler: Calculating batch statistics for " + f"{self.num_samples} samples with batch size {self.batch_size}" + ) + + sorted_lengths = [lengths[i] for i in self.order][: self.num_samples] + total_tokens_removed = 0 + total_tokens_added = 0 + + for i in range(0, self.num_samples, self.batch_size): + batch_lengths = sorted_lengths[i : i + self.batch_size] + if not batch_lengths: + continue + + shortest_in_batch = min(batch_lengths) + longest_in_batch = max(batch_lengths) + tokens_removed = sum(lgth - shortest_in_batch for lgth in batch_lengths) + tokens_added = sum(longest_in_batch - lgth for lgth in batch_lengths) + + total_tokens_removed += tokens_removed + total_tokens_added += tokens_added + + if total_tokens_removed > 0 or total_tokens_added > 0: + logger.debug( + f"LengthAwareSampler: Total token overhead - " + f"removed (truncation): {total_tokens_removed}, " + f"added (padding): {total_tokens_added}" + ) @property def num_samples(self) -> int: diff --git a/tests/llmcompressor/datasets/__init__.py b/tests/llmcompressor/datasets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/llmcompressor/datasets/test_length_aware_sampler.py b/tests/llmcompressor/datasets/test_length_aware_sampler.py new file mode 100644 index 0000000000..3f19d18d99 --- /dev/null +++ b/tests/llmcompressor/datasets/test_length_aware_sampler.py @@ -0,0 +1,42 @@ +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from llmcompressor.datasets.utils import LengthAwareSampler + + +def _create_mock_dataset(lengths: list[int]) -> Dataset: + """Create a mock dataset with input_ids of specified lengths.""" + return Dataset.from_dict({"input_ids": [[0] * length for length in lengths]}) + + +class TestLengthAwareSampler: + """Tests for LengthAwareSampler batch statistics logging.""" + + @pytest.mark.unit + def test_batch_size_parameter(self): + dataset = _create_mock_dataset([100, 200, 300]) + sampler = LengthAwareSampler(dataset, batch_size=4) + assert sampler.batch_size == 4 + + @pytest.mark.unit + def test_logging_called_when_batch_size_greater_than_one(self): + dataset = _create_mock_dataset([100, 150, 200, 250]) + + with patch("llmcompressor.datasets.utils.logger") as mock_logger: + LengthAwareSampler(dataset, batch_size=2) + debug_calls = [str(c) for c in mock_logger.debug.call_args_list] + assert any("Calculating batch statistics" in c for c in debug_calls) + + @pytest.mark.unit + def test_tokens_added_calculation(self): + dataset = _create_mock_dataset([100, 200, 300, 150]) + + with patch("llmcompressor.datasets.utils.logger") as mock_logger: + LengthAwareSampler(dataset, batch_size=2) + + debug_calls = [str(c) for c in mock_logger.debug.call_args_list] + assert any( + "added (padding): 150" in c for c in debug_calls + ), f"Expected 'added (padding): 150' in {debug_calls}"