diff --git a/pytext/data/__init__.py b/pytext/data/__init__.py
index e49db6859..dd8301ee2 100644
--- a/pytext/data/__init__.py
+++ b/pytext/data/__init__.py
@@ -12,6 +12,7 @@ from .data_handler import BatchIterator, CommonMetadata, DataHandler
 from .disjoint_multitask_data import DisjointMultitaskData
 from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler
+from .dynamic_pooling_batcher import DynamicPoolingBatcher
 from .tensorizers import Tensorizer
@@ -25,6 +26,7 @@
     "DataHandler",
     "DisjointMultitaskData",
     "DisjointMultitaskDataHandler",
+    "DynamicPoolingBatcher",
     "EvalBatchSampler",
     "generator_iterator",
     "PoolingBatcher",
diff --git a/pytext/data/data.py b/pytext/data/data.py
index 3f9248edb..24b893d77 100644
--- a/pytext/data/data.py
+++ b/pytext/data/data.py
@@ -119,6 +119,9 @@ def __init__(
         self.pool_num_batches = pool_num_batches
         self.num_shuffled_pools = num_shuffled_pools
 
+    def get_batch_size(self, stage: Stage) -> int:
+        return self._batch_sizes[stage]
+
     def batchify(
         self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
     ):
@@ -130,7 +133,7 @@ def batchify(
         3. Sort rows, if necessary.
         4. Shuffle the order in which the batches are returned, if necessary.
         """
-        batch_size = self._batch_sizes[stage]
+        batch_size = self.get_batch_size(stage)
         pool_size = batch_size * self.pool_num_batches
         super_pool_size = pool_size * self.num_shuffled_pools
 
diff --git a/pytext/data/dynamic_pooling_batcher.py b/pytext/data/dynamic_pooling_batcher.py
new file mode 100644
index 000000000..6c4e95d14
--- /dev/null
+++ b/pytext/data/dynamic_pooling_batcher.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import math
+from typing import Iterable
+
+from pytext.common.constants import Stage
+from pytext.config.module_config import ModuleConfig
+
+from .data import PoolingBatcher
+from .sources import RawExample
+
+
+class BatcherSchedulerConfig(ModuleConfig):
+    # the initial batch size used for training, this is per node
+    start_batch_size: int = 32
+
+    # the final or max batch size to use, any scheduler should
+    # not go over this batch size, this is per node
+    end_batch_size: int = 256
+
+    # the number of epochs to increase the batch size over
+    epoch_period: int = 10
+
+    # the batch size is kept constant for `step_size` number of epochs
+    step_size: int = 1
+
+
+class ExponentialBatcherSchedulerConfig(BatcherSchedulerConfig):
+    # the gamma to increase batch size by
+    gamma: float = 5
+
+
+class DynamicPoolingBatcher(PoolingBatcher):
+    """
+    Allows dynamic batch training, extends pooling batcher with a scheduler
+    config, which specifies how batch size should increase
+    """
+
+    class Config(PoolingBatcher.Config):
+        scheduler_config: BatcherSchedulerConfig = BatcherSchedulerConfig()
+
+    @classmethod
+    def from_config(cls, config: Config):
+        return cls(
+            config.train_batch_size,
+            config.eval_batch_size,
+            config.test_batch_size,
+            config.pool_num_batches,
+            config.num_shuffled_pools,
+            config.scheduler_config,
+        )
+
+    def __init__(
+        self,
+        train_batch_size=Config.train_batch_size,
+        eval_batch_size=Config.eval_batch_size,
+        test_batch_size=Config.test_batch_size,
+        pool_num_batches=Config.pool_num_batches,
+        num_shuffled_pools=Config.num_shuffled_pools,
+        scheduler_config=Config.scheduler_config,
+    ):
+        super().__init__(
+            train_batch_size,
+            eval_batch_size,
+            test_batch_size,
+            pool_num_batches,
+            num_shuffled_pools,
+        )
+        self.scheduler_config = scheduler_config
+        self.curr_batch_size: int = 0
+        self.curr_epoch: int = -1
+        self.step_epoch()
+
+    def compute_dynamic_batch_size(
+        self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int
+    ) -> int:
+        raise NotImplementedError()
+
+    def step_epoch(self):
+        self.curr_epoch += 1
+        if self.curr_epoch > self.scheduler_config.epoch_period:
+            # if past the dynamic period just return
+            # the batch size
+            self.curr_batch_size = self.scheduler_config.end_batch_size
+        else:
+            if self.curr_epoch % self.scheduler_config.step_size == 0:
+                old_size = self.curr_batch_size
+                new_size = self.compute_dynamic_batch_size(
+                    curr_epoch=self.curr_epoch,
+                    scheduler_config=self.scheduler_config,
+                    curr_steps=self.curr_epoch // self.scheduler_config.step_size,
+                )
+
+                print(f"increasing batch size from {old_size} to {new_size}")
+
+                self.curr_batch_size = new_size
+
+    def finished_dynamic(self) -> bool:
+        return (
+            self.scheduler_config.epoch_period != -1
+            and self.curr_epoch >= self.scheduler_config.epoch_period
+        )
+
+    def get_batch_size(self, stage: Stage) -> int:
+        if stage == Stage.TRAIN:
+            print(f"using dynamic batch size {self.curr_batch_size}")
+            return self.curr_batch_size
+        else:
+            return self._batch_sizes[stage]
+
+    def batchify(
+        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
+    ):
+        """
+        From an iterable of dicts, yield dicts of lists:
+
+        1. Load `num_shuffled_pools` pools of data, and shuffle them.
+        2. Load a pool (`batch_size * pool_num_batches` examples).
+        3. Sort rows, if necessary.
+        4. Shuffle the order in which the batches are returned, if necessary.
+        """
+        for item in super().batchify(iterable=iterable, sort_key=sort_key, stage=stage):
+            yield item
+        if stage == Stage.TRAIN:
+            # only step scheduler when in train
+            self.step_epoch()
+
+
+class LinearDynamicPoolingBatcher(DynamicPoolingBatcher):
+    """
+    Linear Dynamic Batch Scheduler: scales up batch size linearly
+    """
+
+    def compute_dynamic_batch_size(
+        self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int
+    ) -> int:
+        batch_delta = (
+            scheduler_config.end_batch_size - scheduler_config.start_batch_size
+        )
+        curr_est: float = (
+            batch_delta / (scheduler_config.epoch_period / scheduler_config.step_size)
+        ) * curr_steps + scheduler_config.start_batch_size
+        return math.ceil(curr_est)
+
+
+class ExponentialDynamicPoolingBatcher(DynamicPoolingBatcher):
+    """
+    Exponential Dynamic Batch Scheduler: scales up batch size by a factor of
+    gamma
+    """
+
+    class Config(DynamicPoolingBatcher.Config):
+        scheduler_config: ExponentialBatcherSchedulerConfig
+
+    def __init__(self, *args, **kwargs):
+        self.max_steps = None
+        super().__init__(*args, **kwargs)
+
+    def get_max_steps(self):
+        if self.max_steps:
+            return self.max_steps
+        self.max_steps: float = math.floor(
+            math.log(
+                self.scheduler_config.end_batch_size
+                / self.scheduler_config.start_batch_size
+            )
+            / math.log(self.scheduler_config.gamma)
+        )
+
+        return self.max_steps
+
+    def finished_dynamic(self) -> bool:
+        return (
+            self.scheduler_config.epoch_period != -1
+            and self.curr_epoch >= self.get_max_steps()
+        )
+
+    def compute_dynamic_batch_size(
+        self,
+        curr_epoch: int,
+        scheduler_config: ExponentialBatcherSchedulerConfig,
+        curr_steps: int,
+    ) -> int:
+        if curr_steps > self.get_max_steps():
+            return scheduler_config.end_batch_size
+        curr_est: float = scheduler_config.start_batch_size * math.pow(
+            scheduler_config.gamma, curr_steps
+        )
+        return min(math.ceil(curr_est), scheduler_config.end_batch_size)
diff --git a/pytext/data/test/dynamic_pooling_batcher_test.py b/pytext/data/test/dynamic_pooling_batcher_test.py
new file mode 100644
index 000000000..463552dea
--- /dev/null
+++ b/pytext/data/test/dynamic_pooling_batcher_test.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import unittest
+from typing import List
+
+from pytext.data.dynamic_pooling_batcher import (
+    BatcherSchedulerConfig,
+    DynamicPoolingBatcher,
+    ExponentialBatcherSchedulerConfig,
+    ExponentialDynamicPoolingBatcher,
+    LinearDynamicPoolingBatcher,
+)
+from pytext.data.sources import RawExample
+from pytext.utils.test import import_tests_module
+
+
+tests_module = import_tests_module()
+
+
+class DynamicPoolingBatcherTest(unittest.TestCase):
+    @classmethod
+    def _get_dataset(cls, dataset_size: int) -> List[RawExample]:
+        return [([], {})] * dataset_size
+
+    def test_linear_scheduler(self):
+        data = DynamicPoolingBatcherTest._get_dataset(dataset_size=100)
+
+        batch_scheduler_config = BatcherSchedulerConfig(
+            start_batch_size=32, end_batch_size=256, epoch_period=5, step_size=1
+        )
+
+        batcher_config = DynamicPoolingBatcher.Config(
+            train_batch_size=1,
+            eval_batch_size=1,
+            test_batch_size=1,
+            pool_num_batches=1,
+            num_shuffled_pools=1,
+            scheduler_config=batch_scheduler_config,
+        )
+
+        batcher = LinearDynamicPoolingBatcher.from_config(batcher_config)
+
+        # epoch 1
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 32)
+
+        # epoch 2
+        # new size ((256 - 32) / 5) + 32 = 76.8 ~ 77
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 77)
+
+    def test_exponential_scheduler(self):
+        data = DynamicPoolingBatcherTest._get_dataset(dataset_size=100)
+
+        batch_scheduler_config = ExponentialBatcherSchedulerConfig(
+            start_batch_size=32,
+            end_batch_size=256,
+            epoch_period=5,
+            step_size=1,
+            gamma=2,
+        )
+
+        batcher_config = ExponentialDynamicPoolingBatcher.Config(
+            train_batch_size=1,
+            eval_batch_size=1,
+            test_batch_size=1,
+            pool_num_batches=1,
+            num_shuffled_pools=1,
+            scheduler_config=batch_scheduler_config,
+        )
+
+        batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config)
+
+        # epoch 1
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 32)
+
+        # epoch 2
+        # new size 32 * 2^1 = 64
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 64)
+
+    def test_batch_size_greater_than_data(self):
+        data = DynamicPoolingBatcherTest._get_dataset(dataset_size=50)
+
+        batch_scheduler_config = ExponentialBatcherSchedulerConfig(
+            start_batch_size=32,
+            end_batch_size=256,
+            epoch_period=5,
+            step_size=1,
+            gamma=2,
+        )
+
+        batcher_config = ExponentialDynamicPoolingBatcher.Config(
+            train_batch_size=1,
+            eval_batch_size=1,
+            test_batch_size=1,
+            pool_num_batches=1,
+            num_shuffled_pools=1,
+            scheduler_config=batch_scheduler_config,
+        )
+
+        batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config)
+
+        # epoch 1
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 32)
+
+        # epoch 2
+        # new size 32 * 2^1 = 64, but the dataset only has 50 examples
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 50)
+
+    def test_end_of_scheduler(self):
+        data = DynamicPoolingBatcherTest._get_dataset(dataset_size=300)
+
+        batch_scheduler_config = ExponentialBatcherSchedulerConfig(
+            start_batch_size=32,
+            end_batch_size=256,
+            epoch_period=2,
+            step_size=4,
+            gamma=2,
+        )
+
+        batcher_config = ExponentialDynamicPoolingBatcher.Config(
+            train_batch_size=1,
+            eval_batch_size=1,
+            test_batch_size=1,
+            pool_num_batches=1,
+            num_shuffled_pools=1,
+            scheduler_config=batch_scheduler_config,
+        )
+
+        batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config)
+
+        # epoch 1
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 32)
+
+        # pass N epochs
+        no_op_epochs = 4
+        _ = [[item for item in batcher.batchify(data)] for _ in range(no_op_epochs)]
+
+        # after period is passed, batch size should be max batch size
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 256)
+
+    def test_step_size(self):
+        data = DynamicPoolingBatcherTest._get_dataset(dataset_size=64)
+
+        batch_scheduler_config = ExponentialBatcherSchedulerConfig(
+            start_batch_size=32,
+            end_batch_size=256,
+            epoch_period=2,
+            step_size=2,
+            gamma=2,
+        )
+
+        batcher_config = ExponentialDynamicPoolingBatcher.Config(
+            train_batch_size=1,
+            eval_batch_size=1,
+            test_batch_size=1,
+            pool_num_batches=1,
+            num_shuffled_pools=1,
+            scheduler_config=batch_scheduler_config,
+        )
+
+        batcher = ExponentialDynamicPoolingBatcher.from_config(batcher_config)
+
+        # epoch 1
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 32)
+
+        # epoch 2
+        # no op on batch size
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 32)
+
+        # epoch 3
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 64)
+
+        # epoch 4
+        # no op on batch size
+        batches = [item for item in batcher.batchify(data)]
+        self.assertEqual(len(batches[0].raw_data), 64)
diff --git a/pytext/docs/make_config_docs.py b/pytext/docs/make_config_docs.py
index d5363982b..230c1b65d 100644
--- a/pytext/docs/make_config_docs.py
+++ b/pytext/docs/make_config_docs.py
@@ -64,6 +64,14 @@ def find_additional_configs(configs, seen=None):
 
 CONFIG_WRAPPERS = [Config.from_config(config) for config in ALL_CONFIG_CLASSES]
 
+# TODO: REMOVE THIS, just useful for debugging the current diff
+# will not be committed
+config_set = set()
+for config in CONFIG_WRAPPERS:
+    if config.path in config_set:
+        print(f"{config.path} double counted")
+    config_set.add(config.path)
+
 assert len({config.path for config in CONFIG_WRAPPERS}) == len(
     CONFIG_WRAPPERS
 ), "Some configs had exactly identical config paths!"
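
For context, a minimal usage sketch of the exponential scheduler added above; it is illustrative only and not part of the diff. It assumes the batch-size and pooling fields inherited from PoolingBatcher.Config keep their defaults, and it uses a throwaway dataset of empty raw examples the same way the unit tests do.

# Illustrative sketch: drive ExponentialDynamicPoolingBatcher for a few epochs
# and print the batch size it actually produces each epoch.
from pytext.data.dynamic_pooling_batcher import (
    ExponentialBatcherSchedulerConfig,
    ExponentialDynamicPoolingBatcher,
)

scheduler_config = ExponentialBatcherSchedulerConfig(
    start_batch_size=32,  # batch size used for the first epoch
    end_batch_size=256,   # hard cap; the scheduler never goes above this
    epoch_period=10,      # number of epochs over which ramp-up is allowed
    step_size=1,          # grow the batch size every epoch
    gamma=2,              # multiply the batch size by 2 at each step
)

# Assumes the remaining Config fields (train/eval/test batch size, pooling)
# can be left at their inherited defaults.
batcher = ExponentialDynamicPoolingBatcher.from_config(
    ExponentialDynamicPoolingBatcher.Config(scheduler_config=scheduler_config)
)

data = [([], {})] * 1000  # placeholder raw examples, as in the unit tests
for epoch in range(3):
    # batchify() yields training batches and then steps the scheduler at the
    # end of the pass, so the first batch grows 32 -> 64 -> 128 here.
    batches = list(batcher.batchify(data))
    print(f"epoch {epoch}: first batch has {len(batches[0].raw_data)} examples")

Per the implementation, with step = epoch // step_size, the linear schedule resolves to ceil(start + (end - start) / (epoch_period / step_size) * step) and the exponential schedule to min(ceil(start * gamma ** step), end).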