From cfe4281b97b5475ad4b1da919c13baefcfc21404 Mon Sep 17 00:00:00 2001 From: Akshat Shrivastava Date: Tue, 17 Dec 2019 14:13:13 -0800 Subject: [PATCH] Dynamic Batch Scheduler Implementation (#1200) Summary: Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1200 This diff adds support of dynamic batch training in pytext. It creates a new batcher that computes the batch size depending on the current epoch. The diff implements two schedulers: * **Linear**: increases batch size linearly * **Exponential**: increases batch size exponentially # API The dynamic batcher extends the pooling batcher so that most the arguments are there and are pretty consistent. It is important to note that dynamic batch sizes only affects the training batch size, not eval or test. Dynamic batcher holds a new configuration object `scheduler_config`, this contains the information needed to compute dynamic batch sizes namely: ``` class SchedulerConfig(ModuleConfig): # the initial batch size used for training start_batch_size: int = 32 # the final or max batch size to use, any scheduler should # not go over this batch size end_batch_size: int = 256 # the number of epochs to increase the batch size over epoch_period: int = 10 # the batch size is kept constant for `step_size` number of epochs step_size: int = 1 ``` Paper: https://arxiv.org/abs/1711.00489 Reviewed By: seayoung1112, ArmenAg Differential Revision: D18900677 fbshipit-source-id: 63cff4597e04922804c3d551afb6435a6f719591 --- pytext/data/__init__.py | 2 + pytext/data/data.py | 5 +- pytext/data/dynamic_pooling_batcher.py | 190 ++++++++++++++++++ .../data/test/dynamic_pooling_batcher_test.py | 187 +++++++++++++++++ 4 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 pytext/data/dynamic_pooling_batcher.py create mode 100644 pytext/data/test/dynamic_pooling_batcher_test.py diff --git a/pytext/data/__init__.py b/pytext/data/__init__.py index e49db6859..dd8301ee2 100644 --- a/pytext/data/__init__.py +++ b/pytext/data/__init__.py @@ -12,6 +12,7 @@ from .data_handler import BatchIterator, CommonMetadata, DataHandler from .disjoint_multitask_data import DisjointMultitaskData from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler +from .dynamic_pooling_batcher import DynamicPoolingBatcher from .tensorizers import Tensorizer @@ -25,6 +26,7 @@ "DataHandler", "DisjointMultitaskData", "DisjointMultitaskDataHandler", + "DynamicPoolingBatcher", "EvalBatchSampler", "generator_iterator", "PoolingBatcher", diff --git a/pytext/data/data.py b/pytext/data/data.py index 3f9248edb..24b893d77 100644 --- a/pytext/data/data.py +++ b/pytext/data/data.py @@ -119,6 +119,9 @@ def __init__( self.pool_num_batches = pool_num_batches self.num_shuffled_pools = num_shuffled_pools + def get_batch_size(self, stage: Stage) -> int: + return self._batch_sizes[stage] + def batchify( self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN ): @@ -130,7 +133,7 @@ def batchify( 3. Sort rows, if necessary. 4. Shuffle the order in which the batches are returned, if necessary. """ - batch_size = self._batch_sizes[stage] + batch_size = self.get_batch_size(stage) pool_size = batch_size * self.pool_num_batches super_pool_size = pool_size * self.num_shuffled_pools diff --git a/pytext/data/dynamic_pooling_batcher.py b/pytext/data/dynamic_pooling_batcher.py new file mode 100644 index 000000000..6c4e95d14 --- /dev/null +++ b/pytext/data/dynamic_pooling_batcher.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import math +from typing import Iterable + +from pytext.common.constants import Stage +from pytext.config.module_config import ModuleConfig + +from .data import PoolingBatcher +from .sources import RawExample + + +class BatcherSchedulerConfig(ModuleConfig): + # the initial batch size used for training, this is per node + start_batch_size: int = 32 + + # the final or max batch size to use, any scheduler should + # not go over this batch size, this is per node + end_batch_size: int = 256 + + # the number of epochs to increase the batch size over + epoch_period: int = 10 + + # the batch size is kept constant for `step_size` number of epochs + step_size: int = 1 + + +class ExponentialBatcherSchedulerConfig(BatcherSchedulerConfig): + # the gamma to increase batch size by + gamma: float = 5 + + +class DynamicPoolingBatcher(PoolingBatcher): + """ + Allows dynamic batch training, extends pooling batcher with a scheduler + config, which specifies how batch size should increase + """ + + class Config(PoolingBatcher.Config): + scheduler_config: BatcherSchedulerConfig = BatcherSchedulerConfig() + + @classmethod + def from_config(cls, config: Config): + return cls( + config.train_batch_size, + config.eval_batch_size, + config.test_batch_size, + config.pool_num_batches, + config.num_shuffled_pools, + config.scheduler_config, + ) + + def __init__( + self, + train_batch_size=Config.train_batch_size, + eval_batch_size=Config.eval_batch_size, + test_batch_size=Config.test_batch_size, + pool_num_batches=Config.pool_num_batches, + num_shuffled_pools=Config.num_shuffled_pools, + scheduler_config=Config.scheduler_config, + ): + super().__init__( + train_batch_size, + eval_batch_size, + test_batch_size, + pool_num_batches, + num_shuffled_pools, + ) + self.scheduler_config = scheduler_config + self.curr_batch_size: int = 0 + self.curr_epoch: int = -1 + self.step_epoch() + + def compute_dynamic_batch_size( + self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int + ) -> int: + raise NotImplementedError() + + def step_epoch(self): + self.curr_epoch += 1 + if self.curr_epoch > self.scheduler_config.epoch_period: + # if past the dynamic period just return + # the batch size + self.curr_batch_size = self.scheduler_config.end_batch_size + else: + if self.curr_epoch % self.scheduler_config.step_size == 0: + old_size = self.curr_batch_size + new_size = self.compute_dynamic_batch_size( + curr_epoch=self.curr_epoch, + scheduler_config=self.scheduler_config, + curr_steps=self.curr_epoch // self.scheduler_config.step_size, + ) + + print(f"increasing batch size from {old_size} to {new_size}") + + self.curr_batch_size = new_size + + def finished_dynamic(self) -> bool: + return ( + self.scheduler_config.epoch_period != -1 + and self.curr_epoch >= self.scheduler_config.epoch_period + ) + + def get_batch_size(self, stage: Stage) -> int: + if stage == Stage.TRAIN: + print(f"using dynamic batch size {self.curr_batch_size}") + return self.curr_batch_size + else: + return self._batch_sizes[stage] + + def batchify( + self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN + ): + """ + From an iterable of dicts, yield dicts of lists: + + 1. Load `num_shuffled_pools` pools of data, and shuffle them. + 2. Load a pool (`batch_size * pool_num_batches` examples). + 3. Sort rows, if necessary. + 4. Shuffle the order in which the batches are returned, if necessary. + """ + for item in super().batchify(iterable=iterable, sort_key=sort_key, stage=stage): + yield item + if stage == Stage.TRAIN: + # only step scheduler when in train + self.step_epoch() + + +class LinearDynamicPoolingBatcher(DynamicPoolingBatcher): + """ + Linear Dynamic Batch Scheduler: scales up batch size linearly + """ + + def compute_dynamic_batch_size( + self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int + ) -> int: + batch_delta = ( + scheduler_config.end_batch_size - scheduler_config.start_batch_size + ) + curr_est: float = ( + batch_delta / (scheduler_config.epoch_period / scheduler_config.step_size) + ) * curr_steps + scheduler_config.start_batch_size + return math.ceil(curr_est) + + +class ExponentialDynamicPoolingBatcher(DynamicPoolingBatcher): + """ + Exponential Dynamic Batch Scheduler: scales up batch size by a factor of + gamma + """ + + class Config(DynamicPoolingBatcher.Config): + scheduler_config: ExponentialBatcherSchedulerConfig + + def __init__(self, *args, **kwargs): + self.max_steps = None + super().__init__(*args, **kwargs) + + def get_max_steps(self): + if self.max_steps: + return self.max_steps + self.max_steps: float = math.floor( + math.log( + self.scheduler_config.end_batch_size + / self.scheduler_config.start_batch_size + ) + / math.log(self.scheduler_config.gamma) + ) + + return self.max_steps + + def finished_dynamic(self) -> bool: + return ( + self.scheduler_config.epoch_period != -1 + and self.curr_epoch >= self.get_max_steps() + ) + + def compute_dynamic_batch_size( + self, + curr_epoch: int, + scheduler_config: ExponentialBatcherSchedulerConfig, + curr_steps: int, + ) -> int: + if curr_steps > self.get_max_steps(): + return scheduler_config.end_batch_size + curr_est: float = scheduler_config.start_batch_size * math.pow( + scheduler_config.gamma, curr_steps + ) + return min(math.ceil(curr_est), scheduler_config.end_batch_size) diff --git a/pytext/data/test/dynamic_pooling_batcher_test.py b/pytext/data/test/dynamic_pooling_batcher_test.py new file mode 100644 index 000000000..463552dea --- /dev/null +++ b/pytext/data/test/dynamic_pooling_batcher_test.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import unittest +from typing import List + +from pytext.data.dynamic_pooling_batcher import ( + BatcherSchedulerConfig, + DynamicPoolingBatcher, + ExponentialBatcherSchedulerConfig, + ExponentialDynamicPoolingBatcher, + LinearDynamicPoolingBatcher, +) +from pytext.data.sources import RawExample +from pytext.utils.test import import_tests_module + + +tests_module = import_tests_module() + + +class DynamicPoolingBatcherTest(unittest.TestCase): + @classmethod + def _get_dataset(cls, dataset_size: int) -> List[RawExample]: + return [([], {})] * dataset_size + + def test_linear_scheduler(self): + data = DynamicPoolingBatcherTest._get_dataset(dataset_size=100) + + batch_scheduler_config = BatcherSchedulerConfig( + start_batch_size=32, end_batch_size=256, epoch_period=5, step_size=1 + ) + + batcher_config = DynamicPoolingBatcher.Config( + train_batch_size=1, + eval_batch_size=1, + test_batch_size=1, + pool_num_batches=1, + num_shuffled_pools=1, + scheduler_config=batch_scheduler_config, + ) + + batcher = LinearDynamicPoolingBatcher.from_config(batcher_config) + + # epoch 1 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 32) + + # epoch 2 + # new size ()(256-32) / 5) + 32 = 76.8 ~ 77 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 77) + + def test_exponential_scheduler(self): + data = DynamicPoolingBatcherTest._get_dataset(dataset_size=100) + + batch_scheduler_config = ExponentialBatcherSchedulerConfig( + start_batch_size=32, + end_batch_size=256, + epoch_period=5, + step_size=1, + gamma=2, + ) + + batcher_config = ExponentialDynamicPoolingBatcher.Config( + train_batch_size=1, + eval_batch_size=1, + test_batch_size=1, + pool_num_batches=1, + num_shuffled_pools=1, + scheduler_config=batch_scheduler_config, + ) + + batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config) + + # epoch 1 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 32) + + # epoch 2 + # new size 32 * 2^1 = 64 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 64) + + def test_batch_size_greater_than_data(self): + data = DynamicPoolingBatcherTest._get_dataset(dataset_size=50) + + batch_scheduler_config = ExponentialBatcherSchedulerConfig( + start_batch_size=32, + end_batch_size=256, + epoch_period=5, + step_size=1, + gamma=2, + ) + + batcher_config = ExponentialDynamicPoolingBatcher.Config( + train_batch_size=1, + eval_batch_size=1, + test_batch_size=1, + pool_num_batches=1, + num_shuffled_pools=1, + scheduler_config=batch_scheduler_config, + ) + + batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config) + + # epoch 1 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 32) + + # epoch 2 + # new size 32 * 2^1 = 64 / 8 = 8 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 50) + + def end_of_scheduler(self): + data = DynamicPoolingBatcherTest._get_dataset(dataset_size=300) + + batch_scheduler_config = ExponentialBatcherSchedulerConfig( + start_batch_size=32, + end_batch_size=256, + epoch_period=2, + step_size=4, + gamma=2, + ) + + batcher_config = ExponentialDynamicPoolingBatcher.Config( + train_batch_size=1, + eval_batch_size=1, + test_batch_size=1, + pool_num_batches=1, + num_shuffled_pools=1, + scheduler_config=batch_scheduler_config, + ) + + batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config) + + # epoch 1 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 32) + + # pass N epochs + no_op_epochs = 4 + _ = [[item for item in batcher.batchify(data)] for _ in range(no_op_epochs)] + + # after period is passed, batch size should be max batch size + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 256) + + def test_step_size(self): + data = DynamicPoolingBatcherTest._get_dataset(dataset_size=64) + + batch_scheduler_config = ExponentialBatcherSchedulerConfig( + start_batch_size=32, + end_batch_size=256, + epoch_period=2, + step_size=2, + gamma=2, + ) + + batcher_config = ExponentialDynamicPoolingBatcher.Config( + train_batch_size=1, + eval_batch_size=1, + test_batch_size=1, + pool_num_batches=1, + num_shuffled_pools=1, + scheduler_config=batch_scheduler_config, + ) + + batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config) + + # epoch 1 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 32) + + # epoch 2 + # no op on batch size + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 32) + + # epoch 3 + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 64) + + # epoch 4 + # no op on batch size + batches = [item for item in batcher.batchify(data)] + self.assertEqual(len(batches[0].raw_data), 64)