Skip to content
92 changes: 70 additions & 22 deletions src/nnet3bin/nnet3-egs-augment-image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,26 @@
namespace kaldi {
namespace nnet3 {

enum FillMode { kNearest, kReflect };

struct ImageAugmentationConfig {
int32 num_channels;
BaseFloat horizontal_flip_prob;
BaseFloat horizontal_shift;
BaseFloat vertical_shift;
std::string fill_mode_string;

ImageAugmentationConfig():
num_channels(1),
horizontal_flip_prob(0.0),
horizontal_shift(0.0),
vertical_shift(0.0) { }
vertical_shift(0.0),
fill_mode_string("nearest") { }


void Register(ParseOptions *po) {
po->Register("num-channels", &num_channels, "Number of colors in the image."
"It is is important to specify this (helps interpret the image "
"It is important to specify this (helps interpret the image "
"correctly.");
po->Register("horizontal-flip-prob", &horizontal_flip_prob,
"Probability of doing horizontal flip");
Expand All @@ -53,6 +57,9 @@ struct ImageAugmentationConfig {
po->Register("vertical-shift", &vertical_shift,
"Maximum allowed vertical shift as proportion of image "
"height. Padding is with closest pixel.");
po->Register("fill-mode", &fill_mode_string, "Mode for dealing with "
"points outside the image boundary when applying transformation. "
"Choices = {nearest, reflect}");
}

void Check() const {
Expand All @@ -61,6 +68,22 @@ struct ImageAugmentationConfig {
horizontal_flip_prob <= 1);
KALDI_ASSERT(horizontal_shift >= 0 && horizontal_shift <= 1);
KALDI_ASSERT(vertical_shift >= 0 && vertical_shift <= 1);
KALDI_ASSERT(fill_mode_string == "nearest" || fill_mode_string == "reflect");
}

FillMode GetFillMode() const {
FillMode fill_mode;
if (fill_mode_string == "reflect") {
fill_mode = kReflect;
} else {
if (fill_mode_string != "nearest") {
KALDI_ERR << "Choices for --fill-mode are 'nearest' or 'reflect', got: "
<< fill_mode_string;
} else {
fill_mode = kNearest;
}
}
return fill_mode;
}
};

Expand All @@ -76,13 +99,15 @@ struct ImageAugmentationConfig {
*/
void ApplyAffineTransform(MatrixBase<BaseFloat> &transform,
int32 num_channels,
MatrixBase<BaseFloat> *image) {
MatrixBase<BaseFloat> *image,
FillMode fill_mode) {
int32 num_rows = image->NumRows(),
num_cols = image->NumCols(),
height = num_cols / num_channels;
height = num_cols / num_channels,
width = num_rows;
KALDI_ASSERT(num_cols % num_channels == 0);
Matrix<BaseFloat> original_image(*image);
for (int32 r = 0; r < num_rows; r++) {
for (int32 r = 0; r < width; r++) {
for (int32 c = 0; c < height; c++) {
// (r_old, c_old) is the coordinate of the pixel in the original image
// while (r, c) is the coordinate in the new (transformed) image.
Expand All @@ -106,21 +131,41 @@ void ApplyAffineTransform(MatrixBase<BaseFloat> &transform,
weight_21 = (r2 - r_old) * (c_old - c1),
weight_22 = (r_old - r1) * (c_old - c1);
// Handle edge conditions:
if (r1 < 0) {
r1 = 0;
if (r2 < 0) r2 = 0;
}
if (r2 >= num_rows) {
r2 = num_rows - 1;
if (r1 >= num_rows) r1 = num_rows - 1;
}
if (c1 < 0) {
c1 = 0;
if (c2 < 0) c2 = 0;
}
if (c2 >= num_cols) {
c2 = num_cols - 1;
if (c1 >= num_cols) c1 = num_cols - 1;
if (fill_mode == kNearest) {
if (r1 < 0) {
r1 = 0;
if (r2 < 0) r2 = 0;
}
if (r2 >= width) {
r2 = width - 1;
if (r1 >= width) r1 = width - 1;
}
if (c1 < 0) {
c1 = 0;
if (c2 < 0) c2 = 0;
}
if (c2 >= height) {
c2 = height - 1;
if (c1 >= height) c1 = height - 1;
}
} else {
KALDI_ASSERT(fill_mode == kReflect);
if (r1 < 0) {
r1 = - r1;
if (r2 < 0) r2 = - r2;
}
if (r2 >= width) {
r2 = 2 * width -2 - r2;
if (r1 >= width) r1 = 2 * width - 2 - r1;
}
if (c1 < 0) {
c1 = - c1;
if (c2 < 0) c2 = -c2;
}
if (c2 >= height) {
c2 = 2 * height - 2 - c2;
if (c1 >= height) c1 = 2 * height - 2 - c1;
}
}
for (int32 ch = 0; ch < num_channels; ch++) {
// find the values at the 4 points
Expand Down Expand Up @@ -150,6 +195,7 @@ void ApplyAffineTransform(MatrixBase<BaseFloat> &transform,
void PerturbImage(const ImageAugmentationConfig &config,
MatrixBase<BaseFloat> *image) {
config.Check();
FillMode fill_mode = config.GetFillMode();
int32 image_width = image->NumRows(),
num_channels = config.num_channels,
image_height = image->NumCols() / num_channels;
Expand Down Expand Up @@ -227,7 +273,7 @@ void PerturbImage(const ImageAugmentationConfig &config,
transform_mat.AddMatMatMat(1.0, set_origin_mat, kNoTrans,
transform_mat, kNoTrans,
reset_origin_mat, kNoTrans, 0.0);
ApplyAffineTransform(transform_mat, config.num_channels, image);
ApplyAffineTransform(transform_mat, config.num_channels, image, fill_mode);
}


Expand Down Expand Up @@ -279,7 +325,7 @@ int main(int argc, char *argv[]) {
"parameters).\n"
"E.g.:\n"
" nnet3-egs-augment-image --horizontal-flip-prob=0.5 --horizontal-shift=0.1\\\n"
" --vertical-shift=0.1 --srand=103 --num-channels=3 ark:- ark:-\n"
" --vertical-shift=0.1 --srand=103 --num-channels=3 --fill-mode=nearest ark:- ark:-\n"
"\n"
"Requires that each eg contain a NnetIo object 'input', with successive\n"
"'t' values representing different x offsets , and the feature dimension\n"
Expand All @@ -294,6 +340,7 @@ int main(int argc, char *argv[]) {

ParseOptions po(usage);
po.Register("srand", &srand_seed, "Seed for the random number generator");

config.Register(&po);

po.Read(argc, argv);
Expand All @@ -305,6 +352,7 @@ int main(int argc, char *argv[]) {
exit(1);
}


std::string examples_rspecifier = po.GetArg(1),
examples_wspecifier = po.GetArg(2);

Expand Down