Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Use CPUPinned context in ImageRecordIOParser2 #13980

Merged
merged 5 commits into from
Jan 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/io/image_iter_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
bool verbose;
/*! \brief partition the data into multiple parts */
int num_parts;
/*! \brief the index of the part will read*/
/*! \brief the index of the part will read */
int part_index;
/*! \brief the size of a shuffle chunk*/
/*! \brief device id used to create context for internal NDArray */
int device_id;
/*! \brief the size of a shuffle chunk */
size_t shuffle_chunk_size;
/*! \brief the seed for chunk shuffling*/
/*! \brief the seed for chunk shuffling */
int shuffle_chunk_seed;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;
Expand Down Expand Up @@ -163,6 +165,11 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
.describe("Virtually partition the data into these many parts.");
DMLC_DECLARE_FIELD(part_index).set_default(0)
.describe("The *i*-th virtual partition to be read.");
DMLC_DECLARE_FIELD(device_id).set_default(0)
.describe("The device id used to create context for internal NDArray. "\
"Setting device_id to -1 will create Context::CPU(0). Setting "
"device_id to valid positive device id will create "
"Context::CPUPinned(device_id). Default is 0.");
DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0)
.describe("The data shuffle buffer size in MB. Only valid if shuffle is true.");
DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0)
Expand Down
9 changes: 7 additions & 2 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,14 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
shape_vec.push_back(param_.label_width);
TShape label_shape(shape_vec.begin(), shape_vec.end());

out->data.at(0) = NDArray(data_shape, Context::CPU(0), false,
auto ctx = Context::CPU(0);
auto dev_id = param_.device_id;
if (dev_id != -1) {
ctx = Context::CPUPinned(dev_id);
}
out->data.at(0) = NDArray(data_shape, ctx, false,
mshadow::DataType<DType>::kFlag);
out->data.at(1) = NDArray(label_shape, Context::CPU(0), false,
out->data.at(1) = NDArray(label_shape, ctx, false,
mshadow::DataType<real_t>::kFlag);
unit_size_[0] = param_.data_shape.Size();
unit_size_[1] = param_.label_width;
Expand Down