From fb98bd51c0083eceea6d8750112d243e469d0b1a Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 Apr 2016 10:08:28 -0700 Subject: [PATCH 1/3] [IO] Refactor augmeneter to support additional ones and chains of augmenters --- src/io/image_aug_default.cc | 303 +++++++++++++++++++++++++++++++ src/io/image_augmenter.h | 330 ++++++---------------------------- src/io/io.cc | 1 - src/io/iter_image_recordio.cc | 59 +++--- 4 files changed, 392 insertions(+), 301 deletions(-) create mode 100644 src/io/image_aug_default.cc diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc new file mode 100644 index 000000000000..83f1ca3d3d29 --- /dev/null +++ b/src/io/image_aug_default.cc @@ -0,0 +1,303 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file image_aug_default.cc + * \brief Default augmenter. + */ +#include +#include +#include +#include +#include +#include "./image_augmenter.h" +#include "../common/utils.h" + +#if MXNET_USE_OPENCV +// Registers +namespace dmlc { +DMLC_REGISTRY_ENABLE(::mxnet::io::ImageAugmenterReg); +} // namespace dmlc +#endif + +namespace mxnet { +namespace io { + +/*! \brief image augmentation parameters*/ +struct DefaultImageAugmentParam : public dmlc::Parameter { + /*! \brief whether we do random cropping */ + bool rand_crop; + /*! \brief whether we do nonrandom croping */ + int crop_y_start; + /*! \brief whether we do nonrandom croping */ + int crop_x_start; + /*! \brief [-max_rotate_angle, max_rotate_angle] */ + int max_rotate_angle; + /*! \brief max aspect ratio */ + float max_aspect_ratio; + /*! \brief random shear the image [-max_shear_ratio, max_shear_ratio] */ + float max_shear_ratio; + /*! \brief max crop size */ + int max_crop_size; + /*! \brief min crop size */ + int min_crop_size; + /*! \brief max scale ratio */ + float max_random_scale; + /*! \brief min scale_ratio */ + float min_random_scale; + /*! \brief min image size */ + float min_img_size; + /*! \brief max image size */ + float max_img_size; + /*! \brief max random in H channel */ + int random_h; + /*! \brief max random in S channel */ + int random_s; + /*! \brief max random in L channel */ + int random_l; + /*! \brief rotate angle */ + int rotate; + /*! \brief filled color while padding */ + int fill_value; + /*! \brief interpolation method 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand */ + int inter_method; + /*! \brief shape of the image data*/ + TShape data_shape; + // declare parameters + DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) { + DMLC_DECLARE_FIELD(rand_crop).set_default(false) + .describe("Augmentation Param: Whether to random crop on the image"); + DMLC_DECLARE_FIELD(crop_y_start).set_default(-1) + .describe("Augmentation Param: Where to nonrandom crop on y."); + DMLC_DECLARE_FIELD(crop_x_start).set_default(-1) + .describe("Augmentation Param: Where to nonrandom crop on x."); + DMLC_DECLARE_FIELD(max_rotate_angle).set_default(0.0f) + .describe("Augmentation Param: rotated randomly in [-max_rotate_angle, max_rotate_angle]."); + DMLC_DECLARE_FIELD(max_aspect_ratio).set_default(0.0f) + .describe("Augmentation Param: denotes the max ratio of random aspect ratio augmentation."); + DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f) + .describe("Augmentation Param: denotes the max random shearing ratio."); + DMLC_DECLARE_FIELD(max_crop_size).set_default(-1) + .describe("Augmentation Param: Maximum crop size."); + DMLC_DECLARE_FIELD(min_crop_size).set_default(-1) + .describe("Augmentation Param: Minimum crop size."); + DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f) + .describe("Augmentation Param: Maxmum scale ratio."); + DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f) + .describe("Augmentation Param: Minimum scale ratio."); + DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f) + .describe("Augmentation Param: Maxmum image size after resizing."); + DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) + .describe("Augmentation Param: Minimum image size after resizing."); + DMLC_DECLARE_FIELD(random_h).set_default(0) + .describe("Augmentation Param: Maximum value of H channel in HSL color space."); + DMLC_DECLARE_FIELD(random_s).set_default(0) + .describe("Augmentation Param: Maximum value of S channel in HSL color space."); + DMLC_DECLARE_FIELD(random_l).set_default(0) + .describe("Augmentation Param: Maximum value of L channel in HSL color space."); + DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) + .describe("Augmentation Param: Rotate angle."); + DMLC_DECLARE_FIELD(fill_value).set_default(255) + .describe("Augmentation Param: Maximum value of illumination variation."); + DMLC_DECLARE_FIELD(data_shape) + .set_expect_ndim(3).enforce_nonzero() + .describe("Dataset Param: Shape of each instance generated by the DataIter."); + DMLC_DECLARE_FIELD(inter_method).set_default(1) + .describe("Augmentation Param: 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand."); + } +}; + +DMLC_REGISTER_PARAMETER(DefaultImageAugmentParam); + +std::vector ListDefaultAugParams() { + return DefaultImageAugmentParam::__FIELDS__(); +} + +#if MXNET_USE_OPENCV + +#ifdef _MSC_VER +#define M_PI CV_PI +#endif +/*! \brief helper class to do image augmentation */ +class DefaultImageAugmenter : public ImageAugmenter { + public: + // contructor + DefaultImageAugmenter() { + rotateM_ = cv::Mat(2, 3, CV_32F); + } + void Init(const std::vector >& kwargs) override { + std::vector > kwargs_left; + kwargs_left = param_.InitAllowUnknown(kwargs); + for (size_t i = 0; i < kwargs_left.size(); i++) { + if (!strcmp(kwargs_left[i].first.c_str(), "rotate_list")) { + const char* val = kwargs_left[i].second.c_str(); + const char *end = val + strlen(val); + char buf[128]; + while (val < end) { + sscanf(val, "%[^,]", buf); + val += strlen(buf) + 1; + rotate_list_.push_back(atoi(buf)); + } + } + } + } + /*! + * \brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC + * \ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for shrink, bilinear for others) 10-RAND + */ + int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, + int new_height, common::RANDOM_ENGINE *prnd) { + if (inter_method == 9) { + if (new_width > old_width && new_height > old_height) { + return 2; // CV_INTER_CUBIC for enlarge + } else if (new_width rand_uniform_int(0, 4); + return rand_uniform_int(*prnd); + } else { + return inter_method; + } + } + cv::Mat Process(const cv::Mat &src, + common::RANDOM_ENGINE *prnd) override { + using mshadow::index_t; + cv::Mat res; + + // normal augmentation by affine transformation. + if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f + || param_.rotate > 0 || rotate_list_.size() > 0 || param_.max_random_scale != 1.0 + || param_.min_random_scale != 1.0 || param_.max_aspect_ratio != 0.0f + || param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) { + std::uniform_real_distribution rand_uniform(0, 1); + // shear + float s = rand_uniform(*prnd) * param_.max_shear_ratio * 2 - param_.max_shear_ratio; + // rotate + int angle = std::uniform_int_distribution( + -param_.max_rotate_angle, param_.max_rotate_angle)(*prnd); + if (param_.rotate > 0) angle = param_.rotate; + if (rotate_list_.size() > 0) { + angle = rotate_list_[std::uniform_int_distribution(0, rotate_list_.size() - 1)(*prnd)]; + } + float a = cos(angle / 180.0 * M_PI); + float b = sin(angle / 180.0 * M_PI); + // scale + float scale = rand_uniform(*prnd) * + (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale; + // aspect ratio + float ratio = rand_uniform(*prnd) * + param_.max_aspect_ratio * 2 - param_.max_aspect_ratio + 1; + float hs = 2 * scale / (1 + ratio); + float ws = ratio * hs; + // new width and height + float new_width = std::max(param_.min_img_size, + std::min(param_.max_img_size, scale * src.cols)); + float new_height = std::max(param_.min_img_size, + std::min(param_.max_img_size, scale * src.rows)); + cv::Mat M(2, 3, CV_32F); + M.at(0, 0) = hs * a - s * b * ws; + M.at(1, 0) = -b * ws; + M.at(0, 1) = hs * b + s * a * ws; + M.at(1, 1) = a * ws; + float ori_center_width = M.at(0, 0) * src.cols + M.at(0, 1) * src.rows; + float ori_center_height = M.at(1, 0) * src.cols + M.at(1, 1) * src.rows; + M.at(0, 2) = (new_width - ori_center_width) / 2; + M.at(1, 2) = (new_height - ori_center_height) / 2; + CHECK((param_.inter_method >= 1 && param_.inter_method <= 4) || + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,9,10"; + int interpolation_method = GetInterMethod(param_.inter_method, + src.cols, src.rows, new_width, new_height, prnd); + cv::warpAffine(src, temp_, M, cv::Size(new_width, new_height), + interpolation_method, + cv::BORDER_CONSTANT, + cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); + res = temp_; + } else { + res = src; + } + + // crop logic + if (param_.max_crop_size != -1 || param_.min_crop_size != -1) { + CHECK(res.cols >= param_.max_crop_size && res.rows >= \ + param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size) + << "input image size smaller than max_crop_size"; + index_t rand_crop_size = + std::uniform_int_distribution(param_.min_crop_size, param_.max_crop_size)(*prnd); + index_t y = res.rows - rand_crop_size; + index_t x = res.cols - rand_crop_size; + if (param_.rand_crop != 0) { + y = std::uniform_int_distribution(0, y)(*prnd); + x = std::uniform_int_distribution(0, x)(*prnd); + } else { + y /= 2; x /= 2; + } + cv::Rect roi(x, y, rand_crop_size, rand_crop_size); + int interpolation_method = GetInterMethod(param_.inter_method, rand_crop_size, rand_crop_size, + param_.data_shape[2], param_.data_shape[1], prnd); + cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]) + , 0, 0, interpolation_method); + } else { + CHECK(static_cast(res.rows) >= param_.data_shape[1] + && static_cast(res.cols) >= param_.data_shape[2]) + << "input image size smaller than input shape"; + index_t y = res.rows - param_.data_shape[1]; + index_t x = res.cols - param_.data_shape[2]; + if (param_.rand_crop != 0) { + y = std::uniform_int_distribution(0, y)(*prnd); + x = std::uniform_int_distribution(0, x)(*prnd); + } else { + y /= 2; x /= 2; + } + cv::Rect roi(x, y, param_.data_shape[2], param_.data_shape[1]); + res = res(roi); + } + + // color space augmentation + if (param_.random_h != 0 || param_.random_s != 0 || param_.random_l != 0) { + std::uniform_real_distribution rand_uniform(0, 1); + cvtColor(res, res, CV_BGR2HLS); + int h = rand_uniform(*prnd) * param_.random_h * 2 - param_.random_h; + int s = rand_uniform(*prnd) * param_.random_s * 2 - param_.random_s; + int l = rand_uniform(*prnd) * param_.random_l * 2 - param_.random_l; + int temp[3] = {h, l, s}; + int limit[3] = {180, 255, 255}; + for (int i = 0; i < res.rows; ++i) { + for (int j = 0; j < res.cols; ++j) { + for (int k = 0; k < 3; ++k) { + int v = res.at(i, j)[k]; + v += temp[k]; + v = std::max(0, std::min(limit[k], v)); + res.at(i, j)[k] = v; + } + } + } + cvtColor(res, res, CV_HLS2BGR); + } + return res; + } + + private: + // temporal space + cv::Mat temp_; + // rotation param + cv::Mat rotateM_; + // parameters + DefaultImageAugmentParam param_; + /*! \brief list of possible rotate angle */ + std::vector rotate_list_; +}; + +ImageAugmenter* ImageAugmenter::Create(const std::string& name) { + return dmlc::Registry::Find(name)->body(); +} + +MXNET_REGISTER_IMAGE_AUGMENTER(aug_default) +.describe("default augmenter") +.set_body([]() { + return new DefaultImageAugmenter(); + }); +#endif // MXNET_USE_OPENCV +} // namespace io +} // namespace mxnet diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index a359b99d6246..fe0cfef951be 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -1,300 +1,88 @@ /*! * Copyright (c) 2015 by Contributors - * \file image_augmenter_opencv.hpp - * \brief threaded version of page iterator + * \file image_augmenter.h + * \brief Interface of opencv based image augmenter */ #ifndef MXNET_IO_IMAGE_AUGMENTER_H_ #define MXNET_IO_IMAGE_AUGMENTER_H_ +#include + #if MXNET_USE_OPENCV #include -#endif -#include +#include #include #include -#include -#include #include "../common/utils.h" namespace mxnet { namespace io { -/*! \brief image augmentation parameters*/ -struct ImageAugmentParam : public dmlc::Parameter { - /*! \brief whether we do random cropping */ - bool rand_crop; - /*! \brief whether we do nonrandom croping */ - int crop_y_start; - /*! \brief whether we do nonrandom croping */ - int crop_x_start; - /*! \brief [-max_rotate_angle, max_rotate_angle] */ - int max_rotate_angle; - /*! \brief max aspect ratio */ - float max_aspect_ratio; - /*! \brief random shear the image [-max_shear_ratio, max_shear_ratio] */ - float max_shear_ratio; - /*! \brief max crop size */ - int max_crop_size; - /*! \brief min crop size */ - int min_crop_size; - /*! \brief max scale ratio */ - float max_random_scale; - /*! \brief min scale_ratio */ - float min_random_scale; - /*! \brief min image size */ - float min_img_size; - /*! \brief max image size */ - float max_img_size; - /*! \brief max random in H channel */ - int random_h; - /*! \brief max random in S channel */ - int random_s; - /*! \brief max random in L channel */ - int random_l; - /*! \brief rotate angle */ - int rotate; - /*! \brief filled color while padding */ - int fill_value; - /*! \brief interpolation method 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand */ - int inter_method; - /*! \brief shape of the image data*/ - TShape data_shape; - // declare parameters - DMLC_DECLARE_PARAMETER(ImageAugmentParam) { - DMLC_DECLARE_FIELD(rand_crop).set_default(false) - .describe("Augmentation Param: Whether to random crop on the image"); - DMLC_DECLARE_FIELD(crop_y_start).set_default(-1) - .describe("Augmentation Param: Where to nonrandom crop on y."); - DMLC_DECLARE_FIELD(crop_x_start).set_default(-1) - .describe("Augmentation Param: Where to nonrandom crop on x."); - DMLC_DECLARE_FIELD(max_rotate_angle).set_default(0.0f) - .describe("Augmentation Param: rotated randomly in [-max_rotate_angle, max_rotate_angle]."); - DMLC_DECLARE_FIELD(max_aspect_ratio).set_default(0.0f) - .describe("Augmentation Param: denotes the max ratio of random aspect ratio augmentation."); - DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f) - .describe("Augmentation Param: denotes the max random shearing ratio."); - DMLC_DECLARE_FIELD(max_crop_size).set_default(-1) - .describe("Augmentation Param: Maximum crop size."); - DMLC_DECLARE_FIELD(min_crop_size).set_default(-1) - .describe("Augmentation Param: Minimum crop size."); - DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f) - .describe("Augmentation Param: Maxmum scale ratio."); - DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f) - .describe("Augmentation Param: Minimum scale ratio."); - DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f) - .describe("Augmentation Param: Maxmum image size after resizing."); - DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) - .describe("Augmentation Param: Minimum image size after resizing."); - DMLC_DECLARE_FIELD(random_h).set_default(0) - .describe("Augmentation Param: Maximum value of H channel in HSL color space."); - DMLC_DECLARE_FIELD(random_s).set_default(0) - .describe("Augmentation Param: Maximum value of S channel in HSL color space."); - DMLC_DECLARE_FIELD(random_l).set_default(0) - .describe("Augmentation Param: Maximum value of L channel in HSL color space."); - DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) - .describe("Augmentation Param: Rotate angle."); - DMLC_DECLARE_FIELD(fill_value).set_default(255) - .describe("Augmentation Param: Maximum value of illumination variation."); - DMLC_DECLARE_FIELD(data_shape) - .set_expect_ndim(3).enforce_nonzero() - .describe("Dataset Param: Shape of each instance generated by the DataIter."); - DMLC_DECLARE_FIELD(inter_method).set_default(1) - .describe("Augmentation Param: 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand."); - } -}; - -/*! \brief helper class to do image augmentation */ +/*! + * \brief OpenCV based Image augmenter, + * The augmenter can contain internal temp state. + */ class ImageAugmenter { public: - // contructor - ImageAugmenter(void) { -#if MXNET_USE_OPENCV - rotateM_ = cv::Mat(2, 3, CV_32F); -#endif - } - virtual ~ImageAugmenter() { - } - virtual void Init(const std::vector >& kwargs) { - std::vector > kwargs_left; - kwargs_left = param_.InitAllowUnknown(kwargs); - for (size_t i = 0; i < kwargs_left.size(); i++) { - if (!strcmp(kwargs_left[i].first.c_str(), "rotate_list")) { - const char* val = kwargs_left[i].second.c_str(); - const char *end = val + strlen(val); - char buf[128]; - while (val < end) { - sscanf(val, "%[^,]", buf); - val += strlen(buf) + 1; - rotate_list_.push_back(atoi(buf)); - } - } - } - } /*! - *\brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC - *\ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for shrink, bilinear for others) 10-RAND + * \brief Initialize the Operator by setting the parameters + * This function need to be called before all other functions. + * \param kwargs the keyword arguments parameters */ - virtual int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, - int new_height, common::RANDOM_ENGINE *prnd) { - if (inter_method == 9) { - if (new_width > old_width && new_height > old_height) { - return 2; // CV_INTER_CUBIC for enlarge - } else if (new_width rand_uniform_int(0, 4); - return rand_uniform_int(*prnd); - } else { - return inter_method; - } - } -#if MXNET_USE_OPENCV -#ifdef _MSC_VER -#define M_PI CV_PI -#endif + virtual void Init(const std::vector >& kwargs) = 0; /*! - * \brief augment src image, store result into dst + * \brief augment src image. * this function is not thread safe, and will only be called by one thread * however, it will tries to re-use memory space as much as possible * \param src the source image - * \param source of random number - * \param dst the pointer to the place where we want to store the result + * \param prnd pointer to random number generator. + * \return The processed image. */ virtual cv::Mat Process(const cv::Mat &src, - common::RANDOM_ENGINE *prnd) { - using mshadow::index_t; - cv::Mat res; - - // normal augmentation by affine transformation. - if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f - || param_.rotate > 0 || rotate_list_.size() > 0 || param_.max_random_scale != 1.0 - || param_.min_random_scale != 1.0 || param_.max_aspect_ratio != 0.0f - || param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) { - std::uniform_real_distribution rand_uniform(0, 1); - // shear - float s = rand_uniform(*prnd) * param_.max_shear_ratio * 2 - param_.max_shear_ratio; - // rotate - int angle = std::uniform_int_distribution( - -param_.max_rotate_angle, param_.max_rotate_angle)(*prnd); - if (param_.rotate > 0) angle = param_.rotate; - if (rotate_list_.size() > 0) { - angle = rotate_list_[std::uniform_int_distribution(0, rotate_list_.size() - 1)(*prnd)]; - } - float a = cos(angle / 180.0 * M_PI); - float b = sin(angle / 180.0 * M_PI); - // scale - float scale = rand_uniform(*prnd) * - (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale; - // aspect ratio - float ratio = rand_uniform(*prnd) * - param_.max_aspect_ratio * 2 - param_.max_aspect_ratio + 1; - float hs = 2 * scale / (1 + ratio); - float ws = ratio * hs; - // new width and height - float new_width = std::max(param_.min_img_size, - std::min(param_.max_img_size, scale * src.cols)); - float new_height = std::max(param_.min_img_size, - std::min(param_.max_img_size, scale * src.rows)); - cv::Mat M(2, 3, CV_32F); - M.at(0, 0) = hs * a - s * b * ws; - M.at(1, 0) = -b * ws; - M.at(0, 1) = hs * b + s * a * ws; - M.at(1, 1) = a * ws; - float ori_center_width = M.at(0, 0) * src.cols + M.at(0, 1) * src.rows; - float ori_center_height = M.at(1, 0) * src.cols + M.at(1, 1) * src.rows; - M.at(0, 2) = (new_width - ori_center_width) / 2; - M.at(1, 2) = (new_height - ori_center_height) / 2; - CHECK((param_.inter_method >= 1 && param_.inter_method <= 4) || - (param_.inter_method >= 9 && param_.inter_method <= 10)) - << "invalid inter_method: valid value 0,1,2,3,9,10"; - int interpolation_method = GetInterMethod(param_.inter_method, - src.cols, src.rows, new_width, new_height, prnd); - cv::warpAffine(src, temp_, M, cv::Size(new_width, new_height), - interpolation_method, - cv::BORDER_CONSTANT, - cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); - res = temp_; - } else { - res = src; - } - - // crop logic - if (param_.max_crop_size != -1 || param_.min_crop_size != -1) { - CHECK(res.cols >= param_.max_crop_size && res.rows >= \ - param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size) - << "input image size smaller than max_crop_size"; - index_t rand_crop_size = - std::uniform_int_distribution(param_.min_crop_size, param_.max_crop_size)(*prnd); - index_t y = res.rows - rand_crop_size; - index_t x = res.cols - rand_crop_size; - if (param_.rand_crop != 0) { - y = std::uniform_int_distribution(0, y)(*prnd); - x = std::uniform_int_distribution(0, x)(*prnd); - } else { - y /= 2; x /= 2; - } - cv::Rect roi(x, y, rand_crop_size, rand_crop_size); - int interpolation_method = GetInterMethod(param_.inter_method, rand_crop_size, rand_crop_size, - param_.data_shape[2], param_.data_shape[1], prnd); - cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]) - , 0, 0, interpolation_method); - } else { - CHECK(static_cast(res.rows) >= param_.data_shape[1] - && static_cast(res.cols) >= param_.data_shape[2]) - << "input image size smaller than input shape"; - index_t y = res.rows - param_.data_shape[1]; - index_t x = res.cols - param_.data_shape[2]; - if (param_.rand_crop != 0) { - y = std::uniform_int_distribution(0, y)(*prnd); - x = std::uniform_int_distribution(0, x)(*prnd); - } else { - y /= 2; x /= 2; - } - cv::Rect roi(x, y, param_.data_shape[2], param_.data_shape[1]); - res = res(roi); - } - - // color space augmentation - if (param_.random_h != 0 || param_.random_s != 0 || param_.random_l != 0) { - std::uniform_real_distribution rand_uniform(0, 1); - cvtColor(res, res, CV_BGR2HLS); - int h = rand_uniform(*prnd) * param_.random_h * 2 - param_.random_h; - int s = rand_uniform(*prnd) * param_.random_s * 2 - param_.random_s; - int l = rand_uniform(*prnd) * param_.random_l * 2 - param_.random_l; - int temp[3] = {h, l, s}; - int limit[3] = {180, 255, 255}; - for (int i = 0; i < res.rows; ++i) { - for (int j = 0; j < res.cols; ++j) { - for (int k = 0; k < 3; ++k) { - int v = res.at(i, j)[k]; - v += temp[k]; - v = std::max(0, std::min(limit[k], v)); - res.at(i, j)[k] = v; - } - } - } - cvtColor(res, res, CV_HLS2BGR); - } - return res; - } - - -#endif + common::RANDOM_ENGINE *prnd) = 0; + // virtual destructor + virtual ~ImageAugmenter() {} + /*! + * \brief factory function + * \param name Name of the augmenter + * \return The created augmenter. + */ + static ImageAugmenter* Create(const std::string& name); +}; - private: -#if MXNET_USE_OPENCV - // temporal space - cv::Mat temp_; - // rotation param - cv::Mat rotateM_; -#endif - // parameters - ImageAugmentParam param_; - /*! \brief list of possible rotate angle */ - std::vector rotate_list_; +/*! \brief typedef the factory function of data iterator */ +typedef std::function ImageAugmenterFactory; +/*! + * \brief Registry entry for DataIterator factory functions. + */ +struct ImageAugmenterReg + : public dmlc::FunctionRegEntryBase { }; +//-------------------------------------------------------------- +// The following part are API Registration of Iterators +//-------------------------------------------------------------- +/*! + * \brief Macro to register image augmenter + * + * \code + * // example of registering a mnist iterator + * REGISTER_IMAGE_AUGMENTER(aug_default) + * .describe("default augmenter") + * .set_body([]() { + * return new DefaultAugmenter(); + * }); + * \endcode + */ +#define MXNET_REGISTER_IMAGE_AUGMENTER(name) \ + DMLC_REGISTRY_REGISTER(::mxnet::io::ImageAugmenterReg, ImageAugmenterReg, name) +} // namespace io +} // namespace mxnet +#endif // MXNET_USE_OPENCV + +namespace mxnet { +namespace io { +/*! \return the parameter of default augmenter */ +std::vector ListDefaultAugParams(); } // namespace io } // namespace mxnet #endif // MXNET_IO_IMAGE_AUGMENTER_H_ diff --git a/src/io/io.cc b/src/io/io.cc index b275664197c7..4251a96909c1 100644 --- a/src/io/io.cc +++ b/src/io/io.cc @@ -17,7 +17,6 @@ namespace io { // Register parameters in header files DMLC_REGISTER_PARAMETER(BatchParam); DMLC_REGISTER_PARAMETER(PrefetcherParam); -DMLC_REGISTER_PARAMETER(ImageAugmentParam); DMLC_REGISTER_PARAMETER(ImageNormalizeParam); } // namespace io } // namespace mxnet diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc index 2db6d344c101..42ae5c89757c 100644 --- a/src/io/iter_image_recordio.cc +++ b/src/io/iter_image_recordio.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -93,6 +94,8 @@ struct ImageRecParserParam : public dmlc::Parameter { std::string path_imglist; /*! \brief path to image recordio */ std::string path_imgrec; + /*! \brief a sequence of names of image augmenters, seperated by , */ + std::string aug_seq; /*! \brief label-width */ int label_width; /*! \brief input shape */ @@ -112,6 +115,10 @@ struct ImageRecParserParam : public dmlc::Parameter { .describe("Dataset Param: Path to image list."); DMLC_DECLARE_FIELD(path_imgrec).set_default("./data/imgrec.rec") .describe("Dataset Param: Path to image record file."); + DMLC_DECLARE_FIELD(aug_seq).set_default("aug_default") + .describe("Augmentation Param: the augmenter names to represent"\ + " sequence of augmenters to be applied, seperated by comma." \ + " Additional keyword parameters will be seen by these augmenters."); DMLC_DECLARE_FIELD(label_width).set_lower_bound(1).set_default(1) .describe("Dataset Param: How many labels for an image."); DMLC_DECLARE_FIELD(data_shape) @@ -131,21 +138,6 @@ struct ImageRecParserParam : public dmlc::Parameter { // parser to parse image recordio class ImageRecordIOParser { public: - ImageRecordIOParser(void) - : source_(nullptr), - label_map_(nullptr) { - } - ~ImageRecordIOParser(void) { - // can be nullptr - delete label_map_; - delete source_; - for (size_t i = 0; i < augmenters_.size(); ++i) { - delete augmenters_[i]; - } - for (size_t i = 0; i < prnds_.size(); ++i) { - delete prnds_[i]; - } - } // initialize the parser inline void Init(const std::vector >& kwargs); @@ -162,20 +154,22 @@ class ImageRecordIOParser { static const int kRandMagic = 111; /*! \brief parameters */ ImageRecParserParam param_; + #if MXNET_USE_OPENCV /*! \brief augmenters */ - std::vector augmenters_; + std::vector > > augmenters_; + #endif /*! \brief random samplers */ - std::vector prnds_; + std::vector > prnds_; /*! \brief data source */ - dmlc::InputSplit *source_; + std::unique_ptr source_; /*! \brief label information, if any */ - ImageLabelMap *label_map_; + std::unique_ptr label_map_; /*! \brief temp space */ mshadow::TensorContainer img_; }; inline void ImageRecordIOParser::Init( - const std::vector >& kwargs) { + const std::vector >& kwargs) { #if MXNET_USE_OPENCV // initialize parameter // init image rec param @@ -193,15 +187,20 @@ inline void ImageRecordIOParser::Init( } param_.preprocess_threads = threadget; + std::vector aug_names = dmlc::Split(param_.aug_seq, ','); + augmenters_.clear(); + augmenters_.resize(threadget); // setup decoders for (int i = 0; i < threadget; ++i) { - augmenters_.push_back(new ImageAugmenter()); - augmenters_[i]->Init(kwargs); - prnds_.push_back(new common::RANDOM_ENGINE((i + 1) * kRandMagic)); + for (const auto& aug_name : aug_names) { + augmenters_[i].emplace_back(ImageAugmenter::Create(aug_name)); + augmenters_[i].back()->Init(kwargs); + } + prnds_.emplace_back(new common::RANDOM_ENGINE((i + 1) * kRandMagic)); } if (param_.path_imglist.length() != 0) { - label_map_ = new ImageLabelMap(param_.path_imglist.c_str(), - param_.label_width, !param_.verbose); + label_map_.reset(new ImageLabelMap(param_.path_imglist.c_str(), + param_.label_width, !param_.verbose)); } else { param_.label_width = 1; } @@ -212,9 +211,9 @@ inline void ImageRecordIOParser::Init( LOG(INFO) << "ImageRecordIOParser: " << param_.path_imgrec << ", use " << threadget << " threads for decoding.."; } - source_ = dmlc::InputSplit::Create( + source_.reset(dmlc::InputSplit::Create( param_.path_imgrec.c_str(), param_.part_index, - param_.num_parts, "recordio"); + param_.num_parts, "recordio")); // use 64 MB chunk when possible source_->HintChunkSize(8 << 20UL); #else @@ -248,7 +247,9 @@ ParseNext(std::vector *out_vec) { // -1 to keep the number of channel of the encoded image, and not force gray or color. res = cv::imdecode(buf, -1); const int n_channels = res.channels(); - res = augmenters_[tid]->Process(res, prnds_[tid]); + for (auto& aug : augmenters_[tid]) { + res = aug->Process(res, prnds_[tid].get()); + } out.Push(static_cast(rec.image_index()), mshadow::Shape3(n_channels, res.rows, res.cols), mshadow::Shape1(param_.label_width)); @@ -401,7 +402,7 @@ MXNET_REGISTER_IO_ITER(ImageRecordIter) .add_arguments(ImageRecordParam::__FIELDS__()) .add_arguments(BatchParam::__FIELDS__()) .add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ImageAugmentParam::__FIELDS__()) +.add_arguments(ListDefaultAugParams()) .add_arguments(ImageNormalizeParam::__FIELDS__()) .set_body([]() { return new PrefetcherIter( From c6a4f018be642e5baa6a3d71b46379ebdf3d3062 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 Apr 2016 10:09:06 -0700 Subject: [PATCH 2/3] update mshadow to latest --- mshadow | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mshadow b/mshadow index be90db115c87..f08c7b624a9b 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit be90db115c876c987bbc5957ea5e8b33e16ec420 +Subproject commit f08c7b624a9b2f7fb67052c2a39b264de5a9ab90 From 51365c0d2e559153dc4fb9cf0d6454b91f465941 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 Apr 2016 11:28:08 -0700 Subject: [PATCH 3/3] fix lint --- src/io/image_augmenter.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index fe0cfef951be..00d7ddd3fc70 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -10,9 +10,10 @@ #if MXNET_USE_OPENCV #include -#include -#include -#include +#include // NOLINT(*) +#include // NOLINT(*) +#include // NOLINT(*) + #include "../common/utils.h" namespace mxnet {