From 190e5f938cb012dac59647d422dee3acbd5ddda2 Mon Sep 17 00:00:00 2001
From: Janek Ebbers
Date: Thu, 21 Dec 2023 12:35:47 +0100
Subject: [PATCH 1/2] fix assess in DynamicTimeSeriesBucket when max_total_size
 is used

---
Note: assess() keeps the existing lower_bound/upper_bound length check and,
when max_total_size is not None, additionally rejects an example if
(len(self.data) + 1) * max(self.max_len, seq_len) would exceed
max_total_size.

 lazy_dataset/core.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/lazy_dataset/core.py b/lazy_dataset/core.py
index ed8a59f..046e46b 100644
--- a/lazy_dataset/core.py
+++ b/lazy_dataset/core.py
@@ -3351,7 +3351,13 @@ def is_completed(self):
 
     def assess(self, example):
         seq_len = self.len_key(example)
-        return self.lower_bound <= seq_len <= self.upper_bound
+        return (
+            (self.lower_bound <= seq_len <= self.upper_bound)
+            and (
+                (self.max_total_size is None)
+                or ((len(self.data) + 1) * max(self.max_len, seq_len) <= self.max_total_size)
+            )
+        )
 
     def _append(self, example):
         super()._append(example)

From 84b1b8b6a41489128afe9501b4947fddfee97bba Mon Sep 17 00:00:00 2001
From: Janek Ebbers
Date: Sat, 30 Dec 2023 12:43:47 -0500
Subject: [PATCH 2/2] add test_max_total_size

---
Note: adds a test that batches the lengths [6, 7, 9, 5, 6, 3, 7, 4] with
batch_size=3, max_padding_rate=0.9 and max_total_size=21 and checks the
resulting buckets.

 tests/test_bucket.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/test_bucket.py b/tests/test_bucket.py
index b457f8f..c8d5764 100644
--- a/tests/test_bucket.py
+++ b/tests/test_bucket.py
@@ -19,3 +19,16 @@ def test_bucket():
     assert dynamic_batched_buckets == [
         [10, 5], [7, 8], [1, 2], [4, 3], [6, 9], [20], [1]
     ]
+
+
+def test_max_total_size():
+    examples = [6, 7, 9, 5, 6, 3, 7, 4]
+    examples = {str(j): i for j, i in enumerate(examples)}
+    ds = lazy_dataset.new(examples)
+
+    dynamic_batched_buckets = list(ds.batch_dynamic_time_series_bucket(
+        batch_size=3, len_key=lambda x: x, max_padding_rate=0.9, max_total_size=21,
+    ))
+    assert dynamic_batched_buckets == [
+        [6, 7, 5], [9, 6], [3, 7, 4]
+    ]