diff --git a/src/nnet3/convolution.cc b/src/nnet3/convolution.cc index 287ab7f47dd..1c5396949f8 100644 --- a/src/nnet3/convolution.cc +++ b/src/nnet3/convolution.cc @@ -976,7 +976,7 @@ static void ComputeTempMatrixSize(const ConvolutionComputationOptions &opts, // work out how many rows the temporary matrix should have, taking // into account the specified memory limit. temp_rows = computation->num_t_out * computation->num_images; - BaseFloat num_megabytes = (4 * temp_rows * temp_cols) / 1000000.0, + BaseFloat num_megabytes = (4 * (temp_rows / 1000.0) * (temp_cols / 1000.0)), megabyte_limit = opts.max_memory_mb; // C++ rounds down; here, we want to round up so we add one. int32 ratio = 1.0 + num_megabytes / megabyte_limit; @@ -986,7 +986,7 @@ static void ComputeTempMatrixSize(const ConvolutionComputationOptions &opts, // >= temp_rows so that we don't have a small leftover piece. int32 new_num_t_out = (computation->num_t_out + ratio - 1) / ratio; temp_rows = new_num_t_out * computation->num_images; - BaseFloat new_num_megabytes = (4 * temp_rows * temp_cols) / 1000000.0; + BaseFloat new_num_megabytes = (4 * (temp_rows / 1000.0) * (temp_cols / 1000.0)); // make sure we're within the memory limit. if (new_num_megabytes > 1.01 * megabyte_limit) { KALDI_WARN << "Memory consumed in convolution is more than requested "