From 098bc49f1d288ea9f2b64453aefcc1537ca5254e Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Mon, 4 Feb 2019 16:56:14 -0800
Subject: [PATCH] Avoid underflow of lp pooling calc for dtype=float16.

---
 src/operator/nn/pool.cuh              | 23 ++++++++------
 src/operator/nn/pool.h                | 46 +++++++++++++++------------
 src/operator/nn/pool_utils.h          | 20 +++++++++---
 tests/python/gpu/test_operator_gpu.py | 23 +++++++++++---
 4 files changed, 73 insertions(+), 39 deletions(-)

diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh
index 69c630508304..671bc7932ef9 100644
--- a/src/operator/nn/pool.cuh
+++ b/src/operator/nn/pool.cuh
@@ -227,6 +227,7 @@ __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data,
                                         const int stride_w, const int pad_w, DType* out_data,
                                         const bool get_avg = false,
                                         const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   CUDA_KERNEL_LOOP(index, nthreads) {
     const bool nwc_layout = layout == mshadow::kNWC;
     const int idx = nwc_layout ? (index / channels) : index;
@@ -241,14 +242,14 @@ __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data,
     if (get_avg && !count_include_pad) {
       pool_size = (wend - wstart);
     }
-    DType sum = 0;
+    AccType sum = 0;
     const DType* out_slice = nwc_layout ? in_data + n * channels * width + c
                                         : in_data + (n * channels + c) * width;
     const int multiplier = nwc_layout ? channels : 1;
     for (int w = wstart; w < wend; ++w) {
-      sum += a_pow_p<DType, p>::Map(out_slice[w * multiplier]) / pool_size;
+      sum += a_pow_p<AccType, p>::Map(out_slice[w * multiplier]) / pool_size;
     }
-    out_data[index] = a_root_p<DType, p>::Map(sum);
+    out_data[index] = a_root_p<AccType, p>::Map(sum);
   }
 }
 
@@ -265,6 +266,7 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
                                         const int pad_h, const int pad_w, DType* out_data,
                                         const bool get_avg = false,
                                         const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   CUDA_KERNEL_LOOP(index, nthreads) {
     const bool nhwc_layout = layout == mshadow::kNHWC;
     const int idx = nhwc_layout ? (index / channels) : index;
@@ -285,16 +287,16 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
     if (get_avg && !count_include_pad) {
       pool_size = (hend - hstart) * (wend - wstart);
     }
-    DType sum = 0;
+    AccType sum = 0;
     const DType* out_slice = nhwc_layout ? in_data + n * channels * height * width + c
                                          : in_data + (n * channels + c) * height * width;
     const int multiplier = nhwc_layout ? channels : 1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        sum += a_pow_p<DType, p>::Map(out_slice[(h * width + w) * multiplier]) / pool_size;
+        sum += a_pow_p<AccType, p>::Map(out_slice[(h * width + w) * multiplier]) / pool_size;
       }
     }
-    out_data[index] = a_root_p<DType, p>::Map(sum);
+    out_data[index] = a_root_p<AccType, p>::Map(sum);
   }
 }
 
@@ -312,6 +314,7 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
                                         const int pad_d, const int pad_h, const int pad_w,
                                         DType* out_data, const bool get_avg = false,
                                         const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   CUDA_KERNEL_LOOP(index, nthreads) {
     const bool ndhwc_layout = layout == mshadow::kNDHWC;
     const int idx = ndhwc_layout ? (index / channels) : index;
@@ -337,21 +340,21 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
     if (get_avg && !count_include_pad) {
       pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
     }
-    DType sum = 0;
+    AccType sum = 0;
     const DType* out_slice = ndhwc_layout ? in_data + n * channels * depth * height * width + c
                                           : in_data + (n * channels + c) * depth * height * width;
     const int multiplier = ndhwc_layout ? channels : 1;
     for (int d = dstart; d < dend; ++d) {
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
-          sum += a_pow_p<DType, p>::Map(out_slice[((d * height + h) * width + w) *
+          sum += a_pow_p<AccType, p>::Map(out_slice[((d * height + h) * width + w) *
                                         multiplier]) / pool_size;
         }
       }
     }
     out_data[index] = (pool_size == 0) ?
-                      DType(nanf("")) :
-                      a_root_p<DType, p>::Map(sum);
+                      AccType(nanf("")) :
+                      a_root_p<AccType, p>::Map(sum);
   }
 }
 
diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h
index 9cd779134228..3c8c19a02607 100644
--- a/src/operator/nn/pool.h
+++ b/src/operator/nn/pool.h
@@ -361,6 +361,7 @@ inline void pool_sum_1d_ncw_cpu(const DType *in_data, const TShape &ishape, cons
                                 const TShape &kernel, const TShape &pad, const TShape &stride,
                                 DType *out_data, const bool get_avg = false,
                                 const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   const int width = ishape[2];
   const int pooled_width = oshape[2];
   const int kernel_w = kernel[0];
@@ -379,11 +380,11 @@ inline void pool_sum_1d_ncw_cpu(const DType *in_data, const TShape &ishape, cons
       if (get_avg && !count_include_pad) {
         pool_size = (wend - wstart);
       }
-      DType sum = 0;
+      AccType sum = 0;
       for (int w = wstart; w < wend; ++w) {
-        sum += a_pow_p<DType, p>::Map(in_data[w]) / pool_size;
+        sum += a_pow_p<AccType, p>::Map(in_data[w]) / pool_size;
       }
-      out_data[pw] = a_root_p<DType, p>::Map(sum);
+      out_data[pw] = a_root_p<AccType, p>::Map(sum);
     }
     in_data += in_data_offset;
     out_data += out_data_offset;
@@ -400,6 +401,7 @@ inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, cons
                                 const TShape& kernel, const TShape& pad, const TShape& stride,
                                 DType* out_data, const bool get_avg = false,
                                 const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   const int width = ishape[1];
   const int pooled_width = oshape[1];
   const int kernel_w = kernel[0];
@@ -408,7 +410,7 @@ inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, cons
   const int features = oshape[2];
   const index_t in_data_offset = ishape[1] * features;
   const index_t out_data_offset = oshape[1] * features;
-  std::vector<DType> sums(features);
+  std::vector<AccType> sums(features);
   for (index_t n = 0; n < oshape[0]; ++n) {
     for (int pw = 0; pw < pooled_width; ++pw) {
       int wstart = pw * stride_w - pad_w;
@@ -422,11 +424,11 @@ inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, cons
       std::fill(sums.begin(), sums.end(), 0);
       for (int w = wstart; w < wend; ++w) {
         for (index_t c = 0; c < features; ++c) {
-          sums[c] += a_pow_p<DType, p>::Map(in_data[w * features + c]) / pool_size;
+          sums[c] += a_pow_p<AccType, p>::Map(in_data[w * features + c]) / pool_size;
         }
       }
       for (index_t c = 0; c < features; ++c)
-        out_data[pw * features + c] = a_root_p<DType, p>::Map(sums[c]);
+        out_data[pw * features + c] = a_root_p<AccType, p>::Map(sums[c]);
     }
     in_data += in_data_offset;
     out_data += out_data_offset;
@@ -442,6 +444,7 @@ inline void pool_sum_2d_nchw_cpu(const DType *in_data, const TShape &ishape, con
                                  const TShape &kernel, const TShape &pad, const TShape &stride,
                                  DType *out_data, const bool get_avg = false,
                                  const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   const int height = ishape[2], width = ishape[3];
   const int pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -465,13 +468,13 @@ inline void pool_sum_2d_nchw_cpu(const DType *in_data, const TShape &ishape, con
         if (get_avg && !count_include_pad) {
           pool_size = (hend - hstart) * (wend - wstart);
         }
-        DType sum = 0;
+        AccType sum = 0;
         for (int h = hstart; h < hend; ++h) {
           for (int w = wstart; w < wend; ++w) {
-            sum += a_pow_p<DType, p>::Map(in_data[h*width+w]) / pool_size;
+            sum += a_pow_p<AccType, p>::Map(in_data[h*width+w]) / pool_size;
           }
         }
-        out_data[ph*pooled_width+pw] = a_root_p<DType, p>::Map(sum);
+        out_data[ph*pooled_width+pw] = a_root_p<AccType, p>::Map(sum);
       }
     }
     in_data += in_data_offset;
@@ -489,6 +492,7 @@ inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, con
                                  const TShape& kernel, const TShape& pad, const TShape& stride,
                                  DType* out_data, const bool get_avg = false,
                                  const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   const int height = ishape[1], width = ishape[2];
   const int pooled_height = oshape[1], pooled_width = oshape[2];
   const int kernel_h = kernel[0], kernel_w = kernel[1];
@@ -497,7 +501,7 @@ inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, con
   const int features = oshape[3];
   const index_t in_data_offset = ishape[1] * ishape[2] * features;
   const index_t out_data_offset = oshape[1] * oshape[2] * features;
-  std::vector<DType> sums(features);
+  std::vector<AccType> sums(features);
   for (index_t n = 0; n < oshape[0]; ++n) {
     for (int ph = 0; ph < pooled_height; ++ph) {
       for (int pw = 0; pw < pooled_width; ++pw) {
@@ -519,12 +523,12 @@ inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, con
         for (int w = wstart; w < wend; ++w) {
           const int in_index = h * width + w;
           for (index_t c = 0; c < features; ++c) {
-            sums[c] += a_pow_p<DType, p>::Map(in_data[in_index * features + c]) / pool_size;
+            sums[c] += a_pow_p<AccType, p>::Map(in_data[in_index * features + c]) / pool_size;
           }
         }
       }
       for (index_t c = 0; c < features; ++c)
-        out_data[pool_index * features + c] = a_root_p<DType, p>::Map(sums[c]);
+        out_data[pool_index * features + c] = a_root_p<AccType, p>::Map(sums[c]);
      }
    }
    in_data += in_data_offset;
@@ -541,6 +545,7 @@ inline void pool_sum_3d_ncdhw_cpu(const DType *in_data, const TShape &ishape, co
                                   const TShape &kernel, const TShape &pad, const TShape &stride,
                                   DType *out_data, const bool get_avg = false,
                                   const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   const int depth = ishape[2], height = ishape[3], width = ishape[4];
   const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -569,17 +574,17 @@ inline void pool_sum_3d_ncdhw_cpu(const DType *in_data, const TShape &ishape, co
           if (get_avg && !count_include_pad) {
             pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
           }
-          DType sum = 0;
+          AccType sum = 0;
           for (int d = dstart; d < dend; ++d) {
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
-                sum += a_pow_p<DType, p>::Map(in_data[(d*height+h)*width+w]) / pool_size;
+                sum += a_pow_p<AccType, p>::Map(in_data[(d*height+h)*width+w]) / pool_size;
               }
             }
           }
           out_data[(pd*pooled_height+ph)*pooled_width+pw] = (pool_size == 0) ?
-                                                            DType(nanf("")) :
-                                                            a_root_p<DType, p>::Map(sum);
+                                                            AccType(nanf("")) :
+                                                            a_root_p<AccType, p>::Map(sum);
         }
       }
     }
@@ -598,6 +603,7 @@ inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, co
                                   const TShape& kernel, const TShape& pad, const TShape& stride,
                                   DType* out_data, const bool get_avg = false,
                                   const bool count_include_pad = true) {
+  using AccType = typename PoolingTypes<DType>::AccType;
   const int depth = ishape[1], height = ishape[2], width = ishape[3];
   const int pooled_depth = oshape[1], pooled_height = oshape[2], pooled_width = oshape[3];
   const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
@@ -606,7 +612,7 @@ inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, co
   const int features = oshape[4];
   const index_t in_data_offset = ishape[1] * ishape[2] * ishape[3] * features;
   const index_t out_data_offset = oshape[1] * oshape[2] * oshape[3] * features;
-  std::vector<DType> sums(features);
+  std::vector<AccType> sums(features);
   for (index_t n = 0; n < oshape[0]; ++n) {
     for (int pd = 0; pd < pooled_depth; ++pd) {
       for (int ph = 0; ph < pooled_height; ++ph) {
@@ -634,15 +640,15 @@ inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, co
             for (int w = wstart; w < wend; ++w) {
               const int in_index = (d * height + h) * width + w;
               for (index_t c = 0; c < features; ++c) {
-                sums[c] += a_pow_p<DType, p>::Map(in_data[in_index * features + c]) / pool_size;
+                sums[c] += a_pow_p<AccType, p>::Map(in_data[in_index * features + c]) / pool_size;
              }
            }
          }
        }
        for (index_t c = 0; c < features; ++c)
          out_data[pool_index * features + c] = (pool_size == 0) ?
-                                               DType(nanf("")) :
-                                               a_root_p<DType, p>::Map(sums[c]);
+                                               AccType(nanf("")) :
+                                               a_root_p<AccType, p>::Map(sums[c]);
      }
    }
  }
diff --git a/src/operator/nn/pool_utils.h b/src/operator/nn/pool_utils.h
index 7b2657451ba7..6bf7235048dc 100644
--- a/src/operator/nn/pool_utils.h
+++ b/src/operator/nn/pool_utils.h
@@ -25,6 +25,17 @@
 namespace mxnet {
 namespace op {
 
+// Define an accumulator type AccType to permit float16-I/O lp pooling to avoid underflow.
+template<typename DType>
+struct PoolingTypes {
+  typedef DType AccType;
+};
+
+template<>
+struct PoolingTypes<mshadow::half::half_t> {
+  typedef float AccType;
+};
+
 template<typename DType, int p>
 struct a_pow_p {
   static MSHADOW_XINLINE DType Map(const DType a) {
@@ -98,16 +109,17 @@ struct lp_grad {
 template<typename DType>
 struct lp_grad<DType, 2> {
   static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
-    // Avoid nan result if both grad and out_data are 0.
-    return (grad == DType(0.0)) ? DType(0.0) : grad * in_data / out_data;
+    // Avoid inf, if out_data has underflowed to 0 for a non-zero input, or nan if grad is also 0.
+    return (out_data == DType(0.0)) ? DType(0.0) : grad * (in_data / out_data);
   }
 };
 
 template<typename DType>
 struct lp_grad<DType, 3> {
   static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
-    // Avoid nan result if both grad and out_data are 0.
-    return (grad == DType(0.0)) ? DType(0.0) : grad * in_data * in_data / (out_data * out_data);
+    // Avoid inf, if out_data has underflowed to 0 for a non-zero input, or nan if grad is also 0.
+    DType in_out_ratio = in_data / out_data;
+    return (out_data == DType(0.0)) ? DType(0.0) : grad * in_out_ratio * in_out_ratio;
   }
 };
 
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index af1e15f3bfa0..ccd94263db15 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -935,7 +935,8 @@ def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, str
 
         check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol)
 
-    def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_include_pad=True):
+    def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_include_pad=True,
+                         tol=None):
         if dim == '1D':
             data = (3, 3, 10)
             kernel = (4,)
@@ -966,7 +967,7 @@ def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_inclu
                                          data=data, kernel=kernel, pad=pad, stride=stride,
                                          pool_type=pool_type, pooling_convention=pooling_convention,
                                          global_pool=False, p_value=p_value,
-                                         count_include_pad=count_include_pad, dtype=dtype)
+                                         count_include_pad=count_include_pad, tol=tol, dtype=dtype)
         except:
             print('pool_op_list = {}'.format(pool_op_list))
             print('kernel={}, pad={}, stride={}'.format(kernel, pad, stride))
@@ -980,7 +981,7 @@ def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_inclu
         test_pooling_versions_helper(pool_op_list=pool_op_list,
                                      data=data, kernel=kernel, pad=None, stride=None,
                                      pool_type=pool_type, global_pool=True, p_value=p_value,
-                                     count_include_pad=count_include_pad, dtype=dtype)
+                                     count_include_pad=count_include_pad, tol=tol, dtype=dtype)
 
     # The various implementations of the standard pooling operator
     std_pool_op_list = ['pool_cpu', 'pool_transposed_cpu',
@@ -1017,15 +1018,27 @@ def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_inclu
     #      only input values should appear in the output.
     #   3. In avg pooling, the 'v1' operator divides the sum by the same window size factor,
     #      even at the edges, and so does not support count_include_pad = False.
+    #   4. The float16 'v1' pooling operator performs forward sums and averages in
+    #      float16, whereas the std operators perform those calculations in float32, so
+    #      greater float16 tolerances are needed when comparing across implementations.
+
+    # Double the float16 tol when comparing v1 and non-v1 implementations, per note 4 above.
+    relaxed_tol = {np.dtype(np.float16): 2e-1,
+                   np.dtype(np.float32): 1e-3,
+                   np.dtype(np.float64): 1e-5,
+                   np.dtype(np.uint8): 0,
+                   np.dtype(np.int32): 0,
+                   np.dtype(np.int64): 0}
 
     # Exclude std implementations due to points 1 and 2 above.
     test_pooling_dim('2D', 'max', dtype, v1_pool_op_list)
     # The standard and 'v1' implementations match for this case.
-    test_pooling_dim('2D', 'avg', dtype, combo_pool_op_list, count_include_pad=True)
+    test_pooling_dim('2D', 'avg', dtype, combo_pool_op_list, count_include_pad=True,
+                     tol=relaxed_tol)
     # Exclude std implementations due to point 3 above.
    test_pooling_dim('2D', 'avg', dtype, v1_pool_op_list, count_include_pad=False)
     # The standard and 'v1' implementations match for this case.
-    test_pooling_dim('2D', 'sum', dtype, combo_pool_op_list)
+    test_pooling_dim('2D', 'sum', dtype, combo_pool_op_list, tol=relaxed_tol)
 
     # We can compare the standard and 'v1' max pooling implementations if we eliminate padding
     # (see point 2 above) and use np.float64 data so that no two random input window values are
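To make the underflow concrete, here is a minimal self-contained NumPy sketch of the failure mode this patch fixes. The lp_pool helper below is an illustrative stand-in for the kernels above, not MXNet code: it accumulates x**p / pool_size over a window in a chosen accumulator dtype and then takes the p-th root, mirroring the patched forward path.

    import numpy as np

    # Minimal model of the forward lp pooling sum: accumulate
    # x**p / pool_size over the window in acc_dtype, then take the
    # p-th root. Helper name and signature are illustrative only.
    def lp_pool(window, p, acc_dtype):
        pool_size = acc_dtype(window.size)
        acc = acc_dtype(0)
        for x in window:
            acc += acc_dtype(x) ** p / pool_size
        return acc ** (acc_dtype(1) / acc_dtype(p))

    window = np.full(4, 1e-3, dtype=np.float16)
    # With p=3, each term is (1e-3)**3 = 1e-9, below float16's smallest
    # subnormal (~6e-8), so a float16 accumulator flushes every term to zero.
    print(lp_pool(window, 3, np.float16))   # 0.0
    print(lp_pool(window, 3, np.float32))   # ~1e-3, the expected result

This is the scenario the patch guards against twice: PoolingTypes<DType>::AccType promotes the per-window accumulation to float32 when DType is float16, and the reworked lp_grad tests out_data == 0 rather than grad == 0, so an output that still underflows yields a zero gradient instead of inf from in_data / out_data.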