Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/llmcompressor/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler:

return RandomSampler(dataset, num_samples=num_samples)
else:
return LengthAwareSampler(dataset, num_samples=num_samples)
return LengthAwareSampler(
dataset, num_samples=num_samples, batch_size=batch_size
)


def data_collator_with_truncation(
Expand Down Expand Up @@ -269,9 +271,11 @@ def __init__(
self,
data_source: Dataset,
num_samples: Optional[int] = None,
batch_size: int = 1,
) -> None:
self.data_source = data_source
self._num_samples = num_samples or len(data_source)
self.batch_size = batch_size

if "input_ids" in data_source.column_names:
feature_name = "input_ids"
Expand All @@ -284,6 +288,40 @@ def __init__(

lengths = [len(sample) for sample in data_source[feature_name]]
self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist()
self._calculate_and_log_batch_stats(lengths)

def _calculate_and_log_batch_stats(self, lengths: list[int]):
    """Log a debug summary of per-batch padding/truncation overhead.

    Walks the length-sorted sample order in chunks of ``batch_size`` and
    tallies how many tokens each batch would lose (truncating every sample
    to the batch's shortest member) or gain (padding every sample to the
    batch's longest member). Purely informational — no state is mutated.

    :param lengths: token length of each sample, indexed like the dataset
    """
    # A batch of one never pads or truncates, so there is nothing to report.
    if self.batch_size == 1:
        return

    logger.debug(
        "LengthAwareSampler: Calculating batch statistics for "
        f"{self.num_samples} samples with batch size {self.batch_size}"
    )

    # Lengths in the order the sampler will yield them, capped at the
    # number of samples actually drawn.
    ordered = [lengths[idx] for idx in self.order][: self.num_samples]

    total_tokens_removed = 0
    total_tokens_added = 0

    start = 0
    while start < self.num_samples:
        batch = ordered[start : start + self.batch_size]
        start += self.batch_size
        if not batch:
            continue

        shortest = min(batch)
        longest = max(batch)
        batch_total = sum(batch)
        # Equivalent to summing each sample's delta against the extremes:
        # sum(l - shortest) == sum(l) - shortest * n, and likewise for padding.
        total_tokens_removed += batch_total - shortest * len(batch)
        total_tokens_added += longest * len(batch) - batch_total

    if total_tokens_removed > 0 or total_tokens_added > 0:
        logger.debug(
            f"LengthAwareSampler: Total token overhead - "
            f"removed (truncation): {total_tokens_removed}, "
            f"added (padding): {total_tokens_added}"
        )

@property
def num_samples(self) -> int:
Expand Down
Empty file.
42 changes: 42 additions & 0 deletions tests/llmcompressor/datasets/test_length_aware_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from unittest.mock import patch

import pytest
from datasets import Dataset

from llmcompressor.datasets.utils import LengthAwareSampler


def _create_mock_dataset(lengths: list[int]) -> Dataset:
    """Build a tiny dataset whose ``input_ids`` rows have the given lengths."""
    rows = {"input_ids": [[0 for _ in range(n)] for n in lengths]}
    return Dataset.from_dict(rows)


class TestLengthAwareSampler:
    """Tests for LengthAwareSampler batch statistics logging."""

    @pytest.mark.unit
    def test_batch_size_parameter(self):
        """The batch_size constructor argument is stored on the sampler."""
        ds = _create_mock_dataset([100, 200, 300])
        assert LengthAwareSampler(ds, batch_size=4).batch_size == 4

    @pytest.mark.unit
    def test_logging_called_when_batch_size_greater_than_one(self):
        """A batch size above one triggers the statistics debug log."""
        ds = _create_mock_dataset([100, 150, 200, 250])

        with patch("llmcompressor.datasets.utils.logger") as mock_logger:
            LengthAwareSampler(ds, batch_size=2)
        messages = [str(call) for call in mock_logger.debug.call_args_list]
        assert any("Calculating batch statistics" in msg for msg in messages)

    @pytest.mark.unit
    def test_tokens_added_calculation(self):
        """Padding overhead is summed correctly across batches.

        Lengths sort to [300, 200, 150, 100]; with batch_size=2 the first
        batch pads 100 tokens and the second pads 50, totalling 150.
        """
        ds = _create_mock_dataset([100, 200, 300, 150])

        with patch("llmcompressor.datasets.utils.logger") as mock_logger:
            LengthAwareSampler(ds, batch_size=2)

        debug_calls = [str(call) for call in mock_logger.debug.call_args_list]
        assert any(
            "added (padding): 150" in call for call in debug_calls
        ), f"Expected 'added (padding): 150' in {debug_calls}"