Dynamic Batch Scheduler Implementation (facebookresearch#1200)
Summary:
Pull Request resolved: facebookresearch#1200

This diff adds support for dynamic batch training in pytext. It creates a new batcher that computes the batch size based on the current epoch.

The diff implements two schedulers:
* **Linear**: increases batch size linearly
* **Exponential**: increases batch size exponentially

# API
The dynamic batcher extends the pooling batcher, so most of the arguments carry over and remain consistent. Note that dynamic batch sizing affects only the training batch size, not eval or test.
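
For example, here is a minimal usage sketch (class and method names come from the new `dynamic_pooling_batcher.py` added below; constructing the batcher directly with its defaults is just for illustration, it would normally be built from config):

```python
from pytext.common.constants import Stage
from pytext.data.dynamic_pooling_batcher import LinearDynamicPoolingBatcher

# Build a batcher with the default scheduler config
# (start_batch_size=32, end_batch_size=256, epoch_period=10, step_size=1).
batcher = LinearDynamicPoolingBatcher()

# Training batches use the scheduled (dynamic) size...
print(batcher.get_batch_size(Stage.TRAIN))  # 32 at epoch 0

# ...while eval/test batch sizes stay fixed at their configured values.
print(batcher.get_batch_size(Stage.EVAL))

# batchify() calls step_epoch() after each pass over the training data,
# which advances the schedule; here we step it manually to show the effect.
batcher.step_epoch()
print(batcher.get_batch_size(Stage.TRAIN))  # 55 after one epoch: ceil(32 + 224/10)
```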

The dynamic batcher holds a new configuration object, `scheduler_config`, which contains the information needed to compute dynamic batch sizes, namely:

```python
class SchedulerConfig(ModuleConfig):
  # the initial batch size used for training
  start_batch_size: int = 32

  # the final or max batch size to use, any scheduler should
  # not go over this batch size
  end_batch_size: int = 256

  # the number of epochs to increase the batch size over
  epoch_period: int = 10

  # the batch size is kept constant for `step_size` number of epochs
  step_size: int = 1
```
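
As a rough illustration, this standalone sketch mirrors the linear and exponential formulas from `dynamic_pooling_batcher.py` below, using the default values above (and the default `gamma=5` for the exponential scheduler):

```python
import math

start, end, period, step_size, gamma = 32, 256, 10, 1, 5

def linear(steps: int) -> int:
    # batch size grows by (end - start) / (period / step_size) per step
    return math.ceil((end - start) / (period / step_size) * steps + start)

def exponential(steps: int) -> int:
    # batch size is multiplied by gamma each step, capped at `end`
    max_steps = math.floor(math.log(end / start) / math.log(gamma))
    if steps > max_steps:
        return end
    return min(math.ceil(start * gamma ** steps), end)

for epoch in range(period + 1):
    print(epoch, linear(epoch), exponential(epoch))
# linear:      32, 55, 77, 100, 122, 144, 167, 189, 212, 234, 256
# exponential: 32, 160, 256, 256, ...
```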

Paper: https://arxiv.org/abs/1711.00489

Reviewed By: seayoung1112, ArmenAg

Differential Revision: D18900677

fbshipit-source-id: 1f1e8ec850b84778a3e7a2c505a7c500de0ad2d5
Akshat Shrivastava authored and facebook-github-bot committed Dec 17, 2019
1 parent 5a7187c commit 146cd50
Showing 4 changed files with 383 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pytext/data/__init__.py
@@ -12,6 +12,7 @@
from .data_handler import BatchIterator, CommonMetadata, DataHandler
from .disjoint_multitask_data import DisjointMultitaskData
from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler
from .dynamic_pooling_batcher import DynamicPoolingBatcher
from .tensorizers import Tensorizer


@@ -25,6 +26,7 @@
    "DataHandler",
    "DisjointMultitaskData",
    "DisjointMultitaskDataHandler",
    "DynamicPoolingBatcher",
    "EvalBatchSampler",
    "generator_iterator",
    "PoolingBatcher",
5 changes: 4 additions & 1 deletion pytext/data/data.py
@@ -119,6 +119,9 @@ def __init__(
        self.pool_num_batches = pool_num_batches
        self.num_shuffled_pools = num_shuffled_pools

    def get_batch_size(self, stage: Stage) -> int:
        return self._batch_sizes[stage]

    def batchify(
        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
    ):
@@ -130,7 +133,7 @@
        3. Sort rows, if necessary.
        4. Shuffle the order in which the batches are returned, if necessary.
        """
-        batch_size = self._batch_sizes[stage]
+        batch_size = self.get_batch_size(stage)
        pool_size = batch_size * self.pool_num_batches
        super_pool_size = pool_size * self.num_shuffled_pools

190 changes: 190 additions & 0 deletions pytext/data/dynamic_pooling_batcher.py
@@ -0,0 +1,190 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math
from typing import Iterable

from pytext.common.constants import Stage
from pytext.config.module_config import ModuleConfig

from .data import PoolingBatcher
from .sources import RawExample


class BatcherSchedulerConfig(ModuleConfig):
    # the initial batch size used for training, this is per node
    start_batch_size: int = 32

    # the final or max batch size to use, any scheduler should
    # not go over this batch size, this is per node
    end_batch_size: int = 256

    # the number of epochs to increase the batch size over
    epoch_period: int = 10

    # the batch size is kept constant for `step_size` number of epochs
    step_size: int = 1


class ExponentialBatcherSchedulerConfig(BatcherSchedulerConfig):
    # the gamma to increase batch size by
    gamma: float = 5


class DynamicPoolingBatcher(PoolingBatcher):
    """
    Allows dynamic batch training, extends pooling batcher with a scheduler
    config, which specifies how batch size should increase
    """

    class Config(PoolingBatcher.Config):
        scheduler_config: BatcherSchedulerConfig = BatcherSchedulerConfig()

    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.train_batch_size,
            config.eval_batch_size,
            config.test_batch_size,
            config.pool_num_batches,
            config.num_shuffled_pools,
            config.scheduler_config,
        )

    def __init__(
        self,
        train_batch_size=Config.train_batch_size,
        eval_batch_size=Config.eval_batch_size,
        test_batch_size=Config.test_batch_size,
        pool_num_batches=Config.pool_num_batches,
        num_shuffled_pools=Config.num_shuffled_pools,
        scheduler_config=Config.scheduler_config,
    ):
        super().__init__(
            train_batch_size,
            eval_batch_size,
            test_batch_size,
            pool_num_batches,
            num_shuffled_pools,
        )
        self.scheduler_config = scheduler_config
        self.curr_batch_size: int = 0
        self.curr_epoch: int = -1
        self.step_epoch()

    def compute_dynamic_batch_size(
        self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int
    ) -> int:
        raise NotImplementedError()

    def step_epoch(self):
        self.curr_epoch += 1
        if self.curr_epoch > self.scheduler_config.epoch_period:
            # if past the dynamic period just return
            # the batch size
            self.curr_batch_size = self.scheduler_config.end_batch_size
        else:
            if self.curr_epoch % self.scheduler_config.step_size == 0:
                old_size = self.curr_batch_size
                new_size = self.compute_dynamic_batch_size(
                    curr_epoch=self.curr_epoch,
                    scheduler_config=self.scheduler_config,
                    curr_steps=self.curr_epoch // self.scheduler_config.step_size,
                )

                print(f"increasing batch size from {old_size} to {new_size}")

                self.curr_batch_size = new_size

    def finished_dynamic(self) -> bool:
        return (
            self.scheduler_config.epoch_period != -1
            and self.curr_epoch >= self.scheduler_config.epoch_period
        )

    def get_batch_size(self, stage: Stage) -> int:
        if stage == Stage.TRAIN:
            print(f"using dynamic batch size {self.curr_batch_size}")
            return self.curr_batch_size
        else:
            return self._batch_sizes[stage]

    def batchify(
        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
    ):
        """
        From an iterable of dicts, yield dicts of lists:
        1. Load `num_shuffled_pools` pools of data, and shuffle them.
        2. Load a pool (`batch_size * pool_num_batches` examples).
        3. Sort rows, if necessary.
        4. Shuffle the order in which the batches are returned, if necessary.
        """
        for item in super().batchify(iterable=iterable, sort_key=sort_key, stage=stage):
            yield item
        if stage == Stage.TRAIN:
            # only step scheduler when in train
            self.step_epoch()


class LinearDynamicPoolingBatcher(DynamicPoolingBatcher):
    """
    Linear Dynamic Batch Scheduler: scales up batch size linearly
    """

    def compute_dynamic_batch_size(
        self, curr_epoch: int, scheduler_config: BatcherSchedulerConfig, curr_steps: int
    ) -> int:
        batch_delta = (
            scheduler_config.end_batch_size - scheduler_config.start_batch_size
        )
        curr_est: float = (
            batch_delta / (scheduler_config.epoch_period / scheduler_config.step_size)
        ) * curr_steps + scheduler_config.start_batch_size
        return math.ceil(curr_est)


class ExponentialDynamicPoolingBatcher(DynamicPoolingBatcher):
    """
    Exponential Dynamic Batch Scheduler: scales up batch size by a factor of
    gamma
    """

    class Config(DynamicPoolingBatcher.Config):
        scheduler_config: ExponentialBatcherSchedulerConfig

    def __init__(self, *args, **kwargs):
        self.max_steps = None
        super().__init__(*args, **kwargs)

    def get_max_steps(self):
        if self.max_steps:
            return self.max_steps
        self.max_steps: float = math.floor(
            math.log(
                self.scheduler_config.end_batch_size
                / self.scheduler_config.start_batch_size
            )
            / math.log(self.scheduler_config.gamma)
        )

        return self.max_steps

    def finished_dynamic(self) -> bool:
        return (
            self.scheduler_config.epoch_period != -1
            and self.curr_epoch >= self.get_max_steps()
        )

    def compute_dynamic_batch_size(
        self,
        curr_epoch: int,
        scheduler_config: ExponentialBatcherSchedulerConfig,
        curr_steps: int,
    ) -> int:
        if curr_steps > self.get_max_steps():
            return scheduler_config.end_batch_size
        curr_est: float = scheduler_config.start_batch_size * math.pow(
            scheduler_config.gamma, curr_steps
        )
        return min(math.ceil(curr_est), scheduler_config.end_batch_size)