Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
create NDArray with CPUPinned context in ImageRecordIOParser2
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxihu committed Jan 24, 2019
1 parent f033aec commit b5c3434
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
10 changes: 7 additions & 3 deletions src/io/image_iter_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
bool verbose;
/*! \brief partition the data into multiple parts */
int num_parts;
/*! \brief the index of the part will read*/
/*! \brief the index of the part will read */
int part_index;
/*! \brief the size of a shuffle chunk*/
/*! \brief device id on which to store the data */
int device_id;
/*! \brief the size of a shuffle chunk */
size_t shuffle_chunk_size;
/*! \brief the seed for chunk shuffling*/
/*! \brief the seed for chunk shuffling */
int shuffle_chunk_seed;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;
Expand Down Expand Up @@ -163,6 +165,8 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
.describe("Virtually partition the data into these many parts.");
DMLC_DECLARE_FIELD(part_index).set_default(0)
.describe("The *i*-th virtual partition to be read.");
DMLC_DECLARE_FIELD(device_id).set_default(0)
.describe("The device id on which to store the data.");
DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0)
.describe("The data shuffle buffer size in MB. Only valid if shuffle is true.");
DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0)
Expand Down
5 changes: 3 additions & 2 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,10 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
shape_vec.push_back(param_.label_width);
TShape label_shape(shape_vec.begin(), shape_vec.end());

out->data.at(0) = NDArray(data_shape, Context::CPU(0), false,
auto dev_id = param_.device_id;
out->data.at(0) = NDArray(data_shape, Context::CPUPinned(dev_id), false,
mshadow::DataType<DType>::kFlag);
out->data.at(1) = NDArray(label_shape, Context::CPU(0), false,
out->data.at(1) = NDArray(label_shape, Context::CPUPinned(dev_id), false,
mshadow::DataType<real_t>::kFlag);
unit_size_[0] = param_.data_shape.Size();
unit_size_[1] = param_.label_width;
Expand Down

0 comments on commit b5c3434

Please sign in to comment.