Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load all data before training for pool_size=-1 #59

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/ui/data_provider/pydataprovider2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ PaddlePaddle from a user defined function. Its parameters are:
* should_shuffle defines whether to shuffle data or not. By default, it is set
  to true during training, and to false during testing.
* pool_size is the memory pool size (in sample number) in DataProvider.
-1 means no limit.
-1 means no limit, and all data will be loaded before training. If your
original data consists of a small number of files or is not shuffled, it is
not recommended to set a different pool_size.
* can_over_batch_size defines whether PaddlePaddle can store little more
samples than pool_size. It is better to set True to avoid some deadlocks.
* calc_batch_size is a function define how to calculate batch size. This is
Expand Down
10 changes: 8 additions & 2 deletions paddle/gserver/dataproviders/PyDataProvider2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class PyDataProvider2 : public DataProvider {
if (!ok) {
this->poolSize_ = -1UL;
}
loadAll_ = poolSize_ == -1UL;
this->canOverBatchSize_ = self.getBoolAttr("can_over_batch_size");

calcBatchSize_.reset(self.getAttr("calc_batch_size"));
Expand Down Expand Up @@ -320,7 +321,9 @@ class PyDataProvider2 : public DataProvider {
CHECK(PyIter_Check(callingContexts_.back()));
}
DBG << "Create context done";
callingContextCreated_.wait();
if (!loadAll_) {
callingContextCreated_.wait();
}

PositionRandom p(skipShuffle_);

Expand Down Expand Up @@ -389,7 +392,9 @@ class PyDataProvider2 : public DataProvider {
}
poolActualSize_ = 0;
exit_ = false;
if (startNewThread && cache_->reset()) {
if (loadAll_) {
loadThread();
} else if (startNewThread && cache_->reset()) {
DBG << "Start new thread.";
loadThread_.reset(new std::thread([this] {
loadThread();
Expand All @@ -413,6 +418,7 @@ class PyDataProvider2 : public DataProvider {

PyObjectPtr instance_;
size_t poolSize_;
bool loadAll_;
bool canOverBatchSize_;
PyObjectPtr calcBatchSize_;
PyObjectPtr generator_;
Expand Down
6 changes: 5 additions & 1 deletion python/paddle/trainer/PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __call__(self, obj, filename):
def provider(input_types=None, should_shuffle=True, pool_size=-1,
can_over_batch_size=True,
calc_batch_size=None,
cache=CacheType.NO_CACHE,
cache=CacheType.CACHE_PASS_IN_MEM,
init_hook=None, **kwargs):
"""
Provider decorator. Use it to make a function into PyDataProvider2 object.
Expand All @@ -133,6 +133,10 @@ def process(settings, file_name):
:param should_shuffle: True if data should shuffle.
:type should_shuffle: bool
:param pool_size: Max number of samples in the data pool.
                  -1 means all data will be loaded before training. If
                  your original data consists of a small number of files
                  or is not shuffled, it is not recommended to set a
                  different pool_size.
:type pool_size: int
:param can_over_batch_size: True if paddle can return a mini-batch larger
than batch size in settings. It is useful when
Expand Down