Dynamic Batch Scheduler Implementation
Summary: This diff adds support for dynamic batch training in PyText. It creates a new batcher that computes the batch size based on the current epoch. The diff implements two schedulers:

* **Linear**: increases the batch size linearly
* **Exponential**: increases the batch size exponentially

# API

The dynamic batcher extends the pooling batcher, so most of the arguments are the same and behave consistently. Note that dynamic batch sizing affects only the training batch size, not eval or test. The dynamic batcher holds a new configuration object, `scheduler_config`, which contains the information needed to compute dynamic batch sizes, namely:

```
class SchedulerConfig(ModuleConfig):
    # the initial batch size used for training
    start_batch_size: int = 32

    # the final or max batch size to use, any scheduler should
    # not go over this batch size
    end_batch_size: int = 256

    # the number of epochs to increase the batch size over
    epoch_period: int = 10

    # the batch size is kept constant for `step_size` number of epochs
    step_size: int = 1
```

Paper: https://arxiv.org/abs/1711.00489

Reviewed By: seayoung1112, ArmenAg

Differential Revision: D18900677

fbshipit-source-id: 5ec95a0c8b206974b5b3c5b00a52223808dd9466
Commit 598ea15, parent 779ba2b: 4 changed files with 378 additions and 1 deletion.
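Before the diff itself, here is a minimal usage sketch. It is not part of the commit; it assumes the classes defined in the file below are importable and that `Config()` can be instantiated with its default field values, using the default `SchedulerConfig` (start 32, end 256, `epoch_period` 10, `step_size` 1).

```
from pytext.common.constants import Stage

# Hypothetical usage sketch, not part of this diff. Assumes
# LinearDynamicPoolingBatcher (defined below) has been imported and
# that Config() picks up all of its default field values.
batcher = LinearDynamicPoolingBatcher.from_config(
    LinearDynamicPoolingBatcher.Config()
)

for _ in range(3):
    # only Stage.TRAIN uses the dynamic size; eval/test sizes stay fixed
    size = batcher.get_batch_size(Stage.TRAIN)  # 32, then 55, then 77
    batcher.step_epoch()  # normally called by batchify() after each train epoch
```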
The new file (187 lines added):

```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math
from typing import Iterable

from pytext.common.constants import Stage
from pytext.config.module_config import ModuleConfig

from .data import PoolingBatcher
from .sources import RawExample


class DynamicPoolingBatcher(PoolingBatcher):
    """
    Allows dynamic batch training, extends pooling batcher with a scheduler
    config, which specifies how batch size should increase
    """

    class Config(PoolingBatcher.Config):
        class SchedulerConfig(ModuleConfig):
            # the initial batch size used for training, this is per node
            start_batch_size: int = 32

            # the final or max batch size to use, any scheduler should
            # not go over this batch size, this is per node
            end_batch_size: int = 256

            # the number of epochs to increase the batch size over
            epoch_period: int = 10

            # the batch size is kept constant for `step_size` number of epochs
            step_size: int = 1

        scheduler_config: SchedulerConfig = SchedulerConfig()

    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.train_batch_size,
            config.eval_batch_size,
            config.test_batch_size,
            config.pool_num_batches,
            config.num_shuffled_pools,
            config.scheduler_config,
        )

    def __init__(
        self,
        train_batch_size=Config.train_batch_size,
        eval_batch_size=Config.eval_batch_size,
        test_batch_size=Config.test_batch_size,
        pool_num_batches=Config.pool_num_batches,
        num_shuffled_pools=Config.num_shuffled_pools,
        scheduler_config=Config.scheduler_config,
    ):
        super().__init__(
            train_batch_size,
            eval_batch_size,
            test_batch_size,
            pool_num_batches,
            num_shuffled_pools,
        )
        self.scheduler_config = scheduler_config
        self.curr_batch_size: int = 0
        self.curr_epoch: int = -1
        self.step_epoch()

    def compute_dynamic_batch_size(
        self, curr_epoch: int, scheduler_config: Config.SchedulerConfig, curr_steps: int
    ) -> int:
        raise NotImplementedError()

    def step_epoch(self):
        self.curr_epoch += 1
        if self.curr_epoch > self.scheduler_config.epoch_period:
            # if past the dynamic period just return
            # the batch size
            self.curr_batch_size = self.scheduler_config.end_batch_size
        else:
            if self.curr_epoch % self.scheduler_config.step_size == 0:
                old_size = self.curr_batch_size
                new_size = self.compute_dynamic_batch_size(
                    curr_epoch=self.curr_epoch,
                    scheduler_config=self.scheduler_config,
                    curr_steps=self.curr_epoch // self.scheduler_config.step_size,
                )

                print(f"increasing batch size from {old_size} to {new_size}")

                self.curr_batch_size = new_size

    def finished_dynamic(self) -> bool:
        return (
            self.scheduler_config.epoch_period != -1
            and self.curr_epoch >= self.scheduler_config.epoch_period
        )

    def get_batch_size(self, stage: Stage) -> int:
        if stage == Stage.TRAIN:
            print(f"using dynamic batch size {self.curr_batch_size}")
            return self.curr_batch_size
        else:
            return self._batch_sizes[stage]

    def batchify(
        self, iterable: Iterable[RawExample], sort_key=None, stage=Stage.TRAIN
    ):
        """
        From an iterable of dicts, yield dicts of lists:
        1. Load `num_shuffled_pools` pools of data, and shuffle them.
        2. Load a pool (`batch_size * pool_num_batches` examples).
        3. Sort rows, if necessary.
        4. Shuffle the order in which the batches are returned, if necessary.
        """
        for item in super().batchify(iterable=iterable, sort_key=sort_key, stage=stage):
            yield item
        if stage == Stage.TRAIN:
            # only step scheduler when in train
            self.step_epoch()
```
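`compute_dynamic_batch_size` is the single extension point: a subclass only needs to map `curr_steps` to a batch size. As a hypothetical illustration (not in the commit), a schedule that doubles the batch size at every step could look like this:

```
# Hypothetical example, not part of this diff: double the batch size
# at every step, clamped to end_batch_size.
class DoublingDynamicPoolingBatcher(DynamicPoolingBatcher):
    def compute_dynamic_batch_size(
        self,
        curr_epoch: int,
        scheduler_config: DynamicPoolingBatcher.Config.SchedulerConfig,
        curr_steps: int,
    ) -> int:
        return min(
            scheduler_config.start_batch_size * (2 ** curr_steps),
            scheduler_config.end_batch_size,
        )
```

The two schedulers shipped in this diff follow the same pattern.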
```
class LinearDynamicPoolingBatcher(DynamicPoolingBatcher):
    """
    Linear Dynamic Batch Scheduler: scales up batch size linearly
    """

    def compute_dynamic_batch_size(
        self,
        curr_epoch: int,
        scheduler_config: DynamicPoolingBatcher.Config.SchedulerConfig,
        curr_steps: int,
    ) -> int:
        batch_delta = (
            scheduler_config.end_batch_size - scheduler_config.start_batch_size
        )
        curr_est: float = (
            batch_delta / (scheduler_config.epoch_period / scheduler_config.step_size)
        ) * curr_steps + scheduler_config.start_batch_size
        return math.ceil(curr_est)
```
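Worked through with the default config: the per-step increment is (256 - 32) / (10 / 1) = 22.4, so with `step_size = 1` the schedule runs 32, ceil(54.4) = 55, ceil(76.8) = 77, and so on, landing exactly on 256 at epoch 10. The exponential variant below replaces this additive step with a multiplicative factor `gamma`.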
```
class ExponentialDynamicPoolingBatcher(DynamicPoolingBatcher):
    """
    Exponential Dynamic Batch Scheduler: scales up batch size by a factor of
    gamma
    """

    class Config(DynamicPoolingBatcher.Config):
        class SchedulerConfig(DynamicPoolingBatcher.Config.SchedulerConfig):
            gamma: float = 5

        scheduler_config: SchedulerConfig

    def __init__(self, *args, **kwargs):
        self.max_steps = None
        super().__init__(*args, **kwargs)

    def get_max_steps(self):
        if self.max_steps:
            return self.max_steps
        self.max_steps: float = math.floor(
            math.log(
                self.scheduler_config.end_batch_size
                / self.scheduler_config.start_batch_size
            )
            / math.log(self.scheduler_config.gamma)
        )

        return self.max_steps

    def finished_dynamic(self) -> bool:
        return (
            self.scheduler_config.epoch_period != -1
            and self.curr_epoch >= self.get_max_steps()
        )

    def compute_dynamic_batch_size(
        self, curr_epoch: int, scheduler_config: Config.SchedulerConfig, curr_steps: int
    ) -> int:
        if curr_steps > self.get_max_steps():
            return scheduler_config.end_batch_size
        curr_est: float = scheduler_config.start_batch_size * math.pow(
            scheduler_config.gamma, curr_steps
        )
        return min(math.ceil(curr_est), scheduler_config.end_batch_size)
```
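Worked through with the defaults (start 32, end 256, gamma 5): `get_max_steps` returns floor(log(256/32) / log(5)) = floor(1.29) = 1, so the schedule is 32, then 32 x 5 = 160, then the clamp to `end_batch_size` gives 256. Note that `finished_dynamic` is overridden here to compare against `get_max_steps()` rather than `epoch_period`, so the exponential scheduler can finish ramping well before the configured period.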