From e999a46a8ca1383e78a9178dc65fa91e6e656c26 Mon Sep 17 00:00:00 2001 From: Yuxi Hu Date: Fri, 25 Jan 2019 10:21:13 -0800 Subject: [PATCH] Use CPUPinned context in ImageRecordIOParser2 (#13980) (#13990) * create NDArray with CPUPinned context in ImageRecordIOParser2 * update document * use -1 device_id as an option to create CPU(0) context * retrigger CI * fix cpplint error --- src/io/image_iter_common.h | 13 ++++++++++--- src/io/iter_image_recordio_2.cc | 9 +++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h index a2324a4b5c5b..dac74427a33d 100644 --- a/src/io/image_iter_common.h +++ b/src/io/image_iter_common.h @@ -125,11 +125,13 @@ struct ImageRecParserParam : public dmlc::Parameter { bool verbose; /*! \brief partition the data into multiple parts */ int num_parts; - /*! \brief the index of the part will read*/ + /*! \brief the index of the part will read */ int part_index; - /*! \brief the size of a shuffle chunk*/ + /*! \brief device id used to create context for internal NDArray */ + int device_id; + /*! \brief the size of a shuffle chunk */ size_t shuffle_chunk_size; - /*! \brief the seed for chunk shuffling*/ + /*! \brief the seed for chunk shuffling */ int shuffle_chunk_seed; // declare parameters @@ -161,6 +163,11 @@ struct ImageRecParserParam : public dmlc::Parameter { .describe("Virtually partition the data into these many parts."); DMLC_DECLARE_FIELD(part_index).set_default(0) .describe("The *i*-th virtual partition to be read."); + DMLC_DECLARE_FIELD(device_id).set_default(0) + .describe("The device id used to create context for internal NDArray. "\ + "Setting device_id to -1 will create Context::CPU(0). Setting " + "device_id to valid positive device id will create " + "Context::CPUPinned(device_id). Default is 0."); DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0) .describe("The data shuffle buffer size in MB. Only valid if shuffle is true."); DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0) diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index b567c729736c..d5c149be288d 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -285,9 +285,14 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { shape_vec.push_back(param_.label_width); TShape label_shape(shape_vec.begin(), shape_vec.end()); - out->data.at(0) = NDArray(data_shape, Context::CPU(0), false, + auto ctx = Context::CPU(0); + auto dev_id = param_.device_id; + if (dev_id != -1) { + ctx = Context::CPUPinned(dev_id); + } + out->data.at(0) = NDArray(data_shape, ctx, false, mshadow::DataType::kFlag); - out->data.at(1) = NDArray(label_shape, Context::CPU(0), false, + out->data.at(1) = NDArray(label_shape, ctx, false, mshadow::DataType::kFlag); unit_size_[0] = param_.data_shape.Size(); unit_size_[1] = param_.label_width;