diff --git a/src/nnet3bin/nnet3-egs-augment-image.cc b/src/nnet3bin/nnet3-egs-augment-image.cc index 99cd1249ffc..07f5d2e91de 100644 --- a/src/nnet3bin/nnet3-egs-augment-image.cc +++ b/src/nnet3bin/nnet3-egs-augment-image.cc @@ -28,22 +28,26 @@ namespace kaldi { namespace nnet3 { +enum FillMode { kNearest, kReflect }; + struct ImageAugmentationConfig { int32 num_channels; BaseFloat horizontal_flip_prob; BaseFloat horizontal_shift; BaseFloat vertical_shift; + std::string fill_mode_string; ImageAugmentationConfig(): num_channels(1), horizontal_flip_prob(0.0), horizontal_shift(0.0), - vertical_shift(0.0) { } + vertical_shift(0.0), + fill_mode_string("nearest") { } void Register(ParseOptions *po) { po->Register("num-channels", &num_channels, "Number of colors in the image." - "It is is important to specify this (helps interpret the image " + "It is important to specify this (helps interpret the image " "correctly."); po->Register("horizontal-flip-prob", &horizontal_flip_prob, "Probability of doing horizontal flip"); @@ -53,6 +57,9 @@ struct ImageAugmentationConfig { po->Register("vertical-shift", &vertical_shift, "Maximum allowed vertical shift as proportion of image " "height. Padding is with closest pixel."); + po->Register("fill-mode", &fill_mode_string, "Mode for dealing with " + "points outside the image boundary when applying transformation. " + "Choices = {nearest, reflect}"); } void Check() const { @@ -61,6 +68,22 @@ struct ImageAugmentationConfig { horizontal_flip_prob <= 1); KALDI_ASSERT(horizontal_shift >= 0 && horizontal_shift <= 1); KALDI_ASSERT(vertical_shift >= 0 && vertical_shift <= 1); + KALDI_ASSERT(fill_mode_string == "nearest" || fill_mode_string == "reflect"); + } + + FillMode GetFillMode() const { + FillMode fill_mode; + if (fill_mode_string == "reflect") { + fill_mode = kReflect; + } else { + if (fill_mode_string != "nearest") { + KALDI_ERR << "Choices for --fill-mode are 'nearest' or 'reflect', got: " + << fill_mode_string; + } else { + fill_mode = kNearest; + } + } + return fill_mode; } }; @@ -76,13 +99,15 @@ struct ImageAugmentationConfig { */ void ApplyAffineTransform(MatrixBase &transform, int32 num_channels, - MatrixBase *image) { + MatrixBase *image, + FillMode fill_mode) { int32 num_rows = image->NumRows(), num_cols = image->NumCols(), - height = num_cols / num_channels; + height = num_cols / num_channels, + width = num_rows; KALDI_ASSERT(num_cols % num_channels == 0); Matrix original_image(*image); - for (int32 r = 0; r < num_rows; r++) { + for (int32 r = 0; r < width; r++) { for (int32 c = 0; c < height; c++) { // (r_old, c_old) is the coordinate of the pixel in the original image // while (r, c) is the coordinate in the new (transformed) image. @@ -106,21 +131,41 @@ void ApplyAffineTransform(MatrixBase &transform, weight_21 = (r2 - r_old) * (c_old - c1), weight_22 = (r_old - r1) * (c_old - c1); // Handle edge conditions: - if (r1 < 0) { - r1 = 0; - if (r2 < 0) r2 = 0; - } - if (r2 >= num_rows) { - r2 = num_rows - 1; - if (r1 >= num_rows) r1 = num_rows - 1; - } - if (c1 < 0) { - c1 = 0; - if (c2 < 0) c2 = 0; - } - if (c2 >= num_cols) { - c2 = num_cols - 1; - if (c1 >= num_cols) c1 = num_cols - 1; + if (fill_mode == kNearest) { + if (r1 < 0) { + r1 = 0; + if (r2 < 0) r2 = 0; + } + if (r2 >= width) { + r2 = width - 1; + if (r1 >= width) r1 = width - 1; + } + if (c1 < 0) { + c1 = 0; + if (c2 < 0) c2 = 0; + } + if (c2 >= height) { + c2 = height - 1; + if (c1 >= height) c1 = height - 1; + } + } else { + KALDI_ASSERT(fill_mode == kReflect); + if (r1 < 0) { + r1 = - r1; + if (r2 < 0) r2 = - r2; + } + if (r2 >= width) { + r2 = 2 * width -2 - r2; + if (r1 >= width) r1 = 2 * width - 2 - r1; + } + if (c1 < 0) { + c1 = - c1; + if (c2 < 0) c2 = -c2; + } + if (c2 >= height) { + c2 = 2 * height - 2 - c2; + if (c1 >= height) c1 = 2 * height - 2 - c1; + } } for (int32 ch = 0; ch < num_channels; ch++) { // find the values at the 4 points @@ -150,6 +195,7 @@ void ApplyAffineTransform(MatrixBase &transform, void PerturbImage(const ImageAugmentationConfig &config, MatrixBase *image) { config.Check(); + FillMode fill_mode = config.GetFillMode(); int32 image_width = image->NumRows(), num_channels = config.num_channels, image_height = image->NumCols() / num_channels; @@ -227,7 +273,7 @@ void PerturbImage(const ImageAugmentationConfig &config, transform_mat.AddMatMatMat(1.0, set_origin_mat, kNoTrans, transform_mat, kNoTrans, reset_origin_mat, kNoTrans, 0.0); - ApplyAffineTransform(transform_mat, config.num_channels, image); + ApplyAffineTransform(transform_mat, config.num_channels, image, fill_mode); } @@ -279,7 +325,7 @@ int main(int argc, char *argv[]) { "parameters).\n" "E.g.:\n" " nnet3-egs-augment-image --horizontal-flip-prob=0.5 --horizontal-shift=0.1\\\n" - " --vertical-shift=0.1 --srand=103 --num-channels=3 ark:- ark:-\n" + " --vertical-shift=0.1 --srand=103 --num-channels=3 --fill-mode=nearest ark:- ark:-\n" "\n" "Requires that each eg contain a NnetIo object 'input', with successive\n" "'t' values representing different x offsets , and the feature dimension\n" @@ -294,6 +340,7 @@ int main(int argc, char *argv[]) { ParseOptions po(usage); po.Register("srand", &srand_seed, "Seed for the random number generator"); + config.Register(&po); po.Read(argc, argv); @@ -305,6 +352,7 @@ int main(int argc, char *argv[]) { exit(1); } + std::string examples_rspecifier = po.GetArg(1), examples_wspecifier = po.GetArg(2);