Add Probabilistic Sampling of batches (facebookresearch#474)
Summary:
Pull Request resolved: facebookresearch#474

Adding a ProbabilisticBatchSampler that samples batches from a dictionary of iterators according to specified per-iterator probabilities, with an optional epoch size. This is needed for training cross-lingual LMs.

Differential Revision: D14929861

fbshipit-source-id: 8c75eae936e5294ce0ce09a82d216d130cdbe9c4
Kartikay Khandelwal authored and facebook-github-bot committed Apr 16, 2019
1 parent eaba52f commit 11149f1
Showing 3 changed files with 124 additions and 11 deletions.
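
For illustration, a minimal usage sketch of the new sampler (not part of this commit; the iterator names and batch contents are made up, and the import assumes the diff below is applied):

from pytext.data import ProbabalisticBatchSampler

# Two hypothetical per-language iterators; each list element stands in for one batch.
iterators = {"en": ["en_batch_1", "en_batch_2"], "fr": ["fr_batch_1", "fr_batch_2"]}

sampler = ProbabalisticBatchSampler(
    iterator_probabilities={"en": 0.5, "fr": 0.5},  # must sum exactly to 1.0
    epoch_size=4,  # yield 4 batches; -1 would end the epoch when an iterator runs out
)
for name, batch in sampler.batchify(iterators):
    print(name, batch)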
8 changes: 7 additions & 1 deletion pytext/data/__init__.py
@@ -1,7 +1,12 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-from .batch_sampler import BaseBatchSampler, EvalBatchSampler, RoundRobinBatchSampler
+from .batch_sampler import (
+    BaseBatchSampler,
+    EvalBatchSampler,
+    ProbabalisticBatchSampler,
+    RoundRobinBatchSampler,
+)
from .bptt_lm_data_handler import BPTTLanguageModelDataHandler
from .compositional_data_handler import CompositionalDataHandler
from .contextual_intent_slot_data_handler import ContextualIntentSlotModelDataHandler
@@ -36,6 +41,7 @@
"LanguageModelDataHandler",
"PairClassificationDataHandler",
"PoolingBatcher",
"ProbabalisticBatchSampler",
"QueryDocumentPairwiseRankingDataHandler",
"RawData",
"RoundRobinBatchSampler",
84 changes: 83 additions & 1 deletion pytext/data/batch_sampler.py
@@ -1,8 +1,10 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from collections.abc import Iterator
-from typing import Dict, Optional
+from typing import Dict, List, Optional

+import numpy as np
from pytext.config.component import Component, ComponentType


@@ -112,3 +114,83 @@ def batchify(self, iterators: Dict[str, Iterator]):
return
else:
yield name, next(new_iter)


class ProbabalisticBatchSampler(BaseBatchSampler):
    """
    This sampler takes in a dictionary of iterators and returns batches according
    to the specified probabilities in iterator_probabilities. The epoch_size
    argument governs the end of the epoch. If epoch_size = -1, then the epoch
    ends when the first iterator runs out of data. Otherwise we cycle through
    the iterators until epoch_size batches have been generated; in this case
    epoch_size refers to the number of batches in the epoch.

    Example:
        Iterator 1: [A, B, C, D], Iterator 2: [a, b]

        epoch_size = -1, iterator_probabilities = [0, 1]
        Output = [a, b]

        epoch_size = -1, iterator_probabilities = [1, 0]
        Output = [A, B, C, D]

        epoch_size = 4, iterator_probabilities = [0, 1]
        Output = [a, b, a, b]
    """

    __COMPONENT_TYPE__ = ComponentType.BATCH_SAMPLER

    class Config(Component.Config):
        iterator_probabilities: Dict[str, float]
        epoch_size: int = -1

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.iterator_probabilities, config.epoch_size)

    def __init__(
        self, iterator_probabilities: Dict[str, float], epoch_size: int = -1
    ) -> None:
        self.epoch_size = epoch_size
        self.iterator_probabilities = iterator_probabilities

    def batchify(self, iterators: Dict[str, Iterator]):
        assert (
            sum(self.iterator_probabilities.values()) == 1.0
        ), "Specified probabilities don't sum to 1."

        for key in iterators.keys():
            assert (
                key in self.iterator_probabilities
            ), "Probability for iterator related to {} is missing".format(key)

        iter_dict = {name: iter(iterator) for name, iterator in iterators.items()}
        self.num_batches = 0

        while True:

            if self.num_batches > 0 and self.num_batches == self.epoch_size:
                # end of epoch
                return

            # Select a candidate iterator according to the specified probabilities
            candidates = list(iter_dict.keys())
            selected_key = np.random.choice(
                candidates, 1, p=list(self.iterator_probabilities.values())
            ).item()
            self.num_batches += 1

            try:
                yield selected_key, next(iter_dict[selected_key])
            except StopIteration:

                # if epoch_size is -1 then this ends the epoch
                if self.epoch_size == -1:
                    return
                else:
                    new_iter = iter(iterators[selected_key])
                    iter_dict[selected_key] = new_iter
                    yield selected_key, next(new_iter)
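
Because batch selection is driven by np.random.choice, the interleaving of iterators can be made reproducible by seeding NumPy before consuming batchify. A hedged sketch with illustrative data and seed (not from this commit):

import numpy as np

from pytext.data import ProbabalisticBatchSampler

np.random.seed(0)  # fixes the sequence of draws made by np.random.choice
sampler = ProbabalisticBatchSampler(
    iterator_probabilities={"A": 0.5, "B": 0.5}, epoch_size=6
)
# Each (name, batch) pair comes from iterator "A" or "B" with equal probability.
batches = list(sampler.batchify({"A": [1, 2, 3], "B": ["x", "y", "z"]}))
assert len(batches) == 6  # epoch_size caps the number of batches per epoch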
43 changes: 34 additions & 9 deletions pytext/data/test/batch_sampler_test.py
@@ -3,32 +3,57 @@

import unittest

-from pytext.data import EvalBatchSampler, RoundRobinBatchSampler
+from pytext.data import (
+    EvalBatchSampler,
+    ProbabalisticBatchSampler,
+    RoundRobinBatchSampler,
+)


class BatchSamplerTest(unittest.TestCase):
-    def test_batch_sampler(self):
-        iteratorA = ["1", "2", "3", "4", "5"]
-        iteratorB = ["a", "b", "c"]
+    def setUp(self):
+        self.iteratorA = ["1", "2", "3", "4", "5"]
+        self.iteratorB = ["a", "b", "c"]
+        self.iter_dict = {"A": self.iteratorA, "B": self.iteratorB}
+
+    def test_round_robin_batch_sampler(self):

        # no iter_to_set_epoch
-        round_robin_iterator = RoundRobinBatchSampler().batchify(
-            {"A": iteratorA, "B": iteratorB}
-        )
+        round_robin_iterator = RoundRobinBatchSampler().batchify(self.iter_dict)
        expected_items = ["1", "a", "2", "b", "3", "c", "4"]
        self._check_iterator(round_robin_iterator, expected_items)

        # iter_to_set_epoch = "A"
        round_robin_iterator = RoundRobinBatchSampler(iter_to_set_epoch="A").batchify(
-            {"A": iteratorA, "B": iteratorB}
+            self.iter_dict
        )
        expected_items = ["1", "a", "2", "b", "3", "c", "4", "a", "5", "b"]
        self._check_iterator(round_robin_iterator, expected_items)

-        eval_iterator = EvalBatchSampler().batchify({"A": iteratorA, "B": iteratorB})
+    def test_eval_batch_sampler(self):
+        eval_iterator = EvalBatchSampler().batchify(self.iter_dict)
        expected_items = ["1", "2", "3", "4", "5", "a", "b", "c"]
        self._check_iterator(eval_iterator, expected_items)

+    def test_prob_batch_sampler(self):
+        prob_iterator = ProbabalisticBatchSampler(
+            iterator_probabilities={"A": 0, "B": 1}, epoch_size=-1
+        ).batchify(self.iter_dict)
+        expected_items = ["a", "b", "c"]
+        self._check_iterator(prob_iterator, expected_items)

+        prob_iterator = ProbabalisticBatchSampler(
+            iterator_probabilities={"A": 1, "B": 0}, epoch_size=-1
+        ).batchify(self.iter_dict)
+        expected_items = ["1", "2", "3", "4", "5"]
+        self._check_iterator(prob_iterator, expected_items)

+        prob_iterator = ProbabalisticBatchSampler(
+            iterator_probabilities={"A": 1, "B": 0}, epoch_size=10
+        ).batchify(self.iter_dict)
+        expected_items = ["1", "2", "3", "4", "5", "1", "2", "3", "4", "5"]
+        self._check_iterator(prob_iterator, expected_items)

    def _check_iterator(self, iterator, expected_items, fixed_order=True):
        actual_items = [item for _, item in iterator]
        if not fixed_order: