Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Avoid underflow of lp pooling calc for dtype=float16.
Browse files Browse the repository at this point in the history
  • Loading branch information
DickJC123 committed Feb 5, 2019
1 parent 95469db commit 098bc49
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 39 deletions.
23 changes: 13 additions & 10 deletions src/operator/nn/pool.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data,
const int stride_w, const int pad_w, DType* out_data,
const bool get_avg = false,
const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
CUDA_KERNEL_LOOP(index, nthreads) {
const bool nwc_layout = layout == mshadow::kNWC;
const int idx = nwc_layout ? (index / channels) : index;
Expand All @@ -241,14 +242,14 @@ __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data,
if (get_avg && !count_include_pad) {
pool_size = (wend - wstart);
}
DType sum = 0;
AccType sum = 0;
const DType* out_slice = nwc_layout ? in_data + n * channels * width + c
: in_data + (n * channels + c) * width;
const int multiplier = nwc_layout ? channels : 1;
for (int w = wstart; w < wend; ++w) {
sum += a_pow_p<DType, p>::Map(out_slice[w * multiplier]) / pool_size;
sum += a_pow_p<AccType, p>::Map(out_slice[w * multiplier]) / pool_size;
}
out_data[index] = a_root_p<DType, p>::Map(sum);
out_data[index] = a_root_p<AccType, p>::Map(sum);
}
}

Expand All @@ -265,6 +266,7 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
const int pad_h, const int pad_w, DType* out_data,
const bool get_avg = false,
const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
CUDA_KERNEL_LOOP(index, nthreads) {
const bool nhwc_layout = layout == mshadow::kNHWC;
const int idx = nhwc_layout ? (index / channels) : index;
Expand All @@ -285,16 +287,16 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data,
if (get_avg && !count_include_pad) {
pool_size = (hend - hstart) * (wend - wstart);
}
DType sum = 0;
AccType sum = 0;
const DType* out_slice = nhwc_layout ? in_data + n * channels * height * width + c
: in_data + (n * channels + c) * height * width;
const int multiplier = nhwc_layout ? channels : 1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += a_pow_p<DType, p>::Map(out_slice[(h * width + w) * multiplier]) / pool_size;
sum += a_pow_p<AccType, p>::Map(out_slice[(h * width + w) * multiplier]) / pool_size;
}
}
out_data[index] = a_root_p<DType, p>::Map(sum);
out_data[index] = a_root_p<AccType, p>::Map(sum);
}
}

Expand All @@ -312,6 +314,7 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
const int pad_d, const int pad_h, const int pad_w,
DType* out_data, const bool get_avg = false,
const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
CUDA_KERNEL_LOOP(index, nthreads) {
const bool ndhwc_layout = layout == mshadow::kNDHWC;
const int idx = ndhwc_layout ? (index / channels) : index;
Expand All @@ -337,21 +340,21 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data,
if (get_avg && !count_include_pad) {
pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
}
DType sum = 0;
AccType sum = 0;
const DType* out_slice = ndhwc_layout ? in_data + n * channels * depth * height * width + c
: in_data + (n * channels + c) * depth * height * width;
const int multiplier = ndhwc_layout ? channels : 1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += a_pow_p<DType, p>::Map(out_slice[((d * height + h) * width + w) *
sum += a_pow_p<AccType, p>::Map(out_slice[((d * height + h) * width + w) *
multiplier]) / pool_size;
}
}
}
out_data[index] = (pool_size == 0) ?
DType(nanf("")) :
a_root_p<DType, p>::Map(sum);
AccType(nanf("")) :
a_root_p<AccType, p>::Map(sum);
}
}

Expand Down
46 changes: 26 additions & 20 deletions src/operator/nn/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ inline void pool_sum_1d_ncw_cpu(const DType *in_data, const TShape &ishape, cons
const TShape &kernel, const TShape &pad, const TShape &stride,
DType *out_data,
const bool get_avg = false, const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
const int width = ishape[2];
const int pooled_width = oshape[2];
const int kernel_w = kernel[0];
Expand All @@ -379,11 +380,11 @@ inline void pool_sum_1d_ncw_cpu(const DType *in_data, const TShape &ishape, cons
if (get_avg && !count_include_pad) {
pool_size = (wend - wstart);
}
DType sum = 0;
AccType sum = 0;
for (int w = wstart; w < wend; ++w) {
sum += a_pow_p<DType, p>::Map(in_data[w]) / pool_size;
sum += a_pow_p<AccType, p>::Map(in_data[w]) / pool_size;
}
out_data[pw] = a_root_p<DType, p>::Map(sum);
out_data[pw] = a_root_p<AccType, p>::Map(sum);
}
in_data += in_data_offset;
out_data += out_data_offset;
Expand All @@ -400,6 +401,7 @@ inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, cons
const TShape& kernel, const TShape& pad, const TShape& stride,
DType* out_data,
const bool get_avg = false, const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
const int width = ishape[1];
const int pooled_width = oshape[1];
const int kernel_w = kernel[0];
Expand All @@ -408,7 +410,7 @@ inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, cons
const int features = oshape[2];
const index_t in_data_offset = ishape[1] * features;
const index_t out_data_offset = oshape[1] * features;
std::vector<DType> sums(features);
std::vector<AccType> sums(features);
for (index_t n = 0; n < oshape[0]; ++n) {
for (int pw = 0; pw < pooled_width; ++pw) {
int wstart = pw * stride_w - pad_w;
Expand All @@ -422,11 +424,11 @@ inline void pool_sum_1d_nwc_cpu(const DType* in_data, const TShape& ishape, cons
std::fill(sums.begin(), sums.end(), 0);
for (int w = wstart; w < wend; ++w) {
for (index_t c = 0; c < features; ++c) {
sums[c] += a_pow_p<DType, p>::Map(in_data[w * features + c]) / pool_size;
sums[c] += a_pow_p<AccType, p>::Map(in_data[w * features + c]) / pool_size;
}
}
for (index_t c = 0; c < features; ++c)
out_data[pw * features + c] = a_root_p<DType, p>::Map(sums[c]);
out_data[pw * features + c] = a_root_p<AccType, p>::Map(sums[c]);
}
in_data += in_data_offset;
out_data += out_data_offset;
Expand All @@ -442,6 +444,7 @@ inline void pool_sum_2d_nchw_cpu(const DType *in_data, const TShape &ishape, con
const TShape &kernel, const TShape &pad, const TShape &stride,
DType *out_data,
const bool get_avg = false, const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
const int height = ishape[2], width = ishape[3];
const int pooled_height = oshape[2], pooled_width = oshape[3];
const int kernel_h = kernel[0], kernel_w = kernel[1];
Expand All @@ -465,13 +468,13 @@ inline void pool_sum_2d_nchw_cpu(const DType *in_data, const TShape &ishape, con
if (get_avg && !count_include_pad) {
pool_size = (hend - hstart) * (wend - wstart);
}
DType sum = 0;
AccType sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += a_pow_p<DType, p>::Map(in_data[h*width+w]) / pool_size;
sum += a_pow_p<AccType, p>::Map(in_data[h*width+w]) / pool_size;
}
}
out_data[ph*pooled_width+pw] = a_root_p<DType, p>::Map(sum);
out_data[ph*pooled_width+pw] = a_root_p<AccType, p>::Map(sum);
}
}
in_data += in_data_offset;
Expand All @@ -489,6 +492,7 @@ inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, con
const TShape& kernel, const TShape& pad, const TShape& stride,
DType* out_data,
const bool get_avg = false, const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
const int height = ishape[1], width = ishape[2];
const int pooled_height = oshape[1], pooled_width = oshape[2];
const int kernel_h = kernel[0], kernel_w = kernel[1];
Expand All @@ -497,7 +501,7 @@ inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, con
const int features = oshape[3];
const index_t in_data_offset = ishape[1] * ishape[2] * features;
const index_t out_data_offset = oshape[1] * oshape[2] * features;
std::vector<DType> sums(features);
std::vector<AccType> sums(features);
for (index_t n = 0; n < oshape[0]; ++n) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
Expand All @@ -519,12 +523,12 @@ inline void pool_sum_2d_nhwc_cpu(const DType* in_data, const TShape& ishape, con
for (int w = wstart; w < wend; ++w) {
const int in_index = h * width + w;
for (index_t c = 0; c < features; ++c) {
sums[c] += a_pow_p<DType, p>::Map(in_data[in_index * features + c]) / pool_size;
sums[c] += a_pow_p<AccType, p>::Map(in_data[in_index * features + c]) / pool_size;
}
}
}
for (index_t c = 0; c < features; ++c)
out_data[pool_index * features + c] = a_root_p<DType, p>::Map(sums[c]);
out_data[pool_index * features + c] = a_root_p<AccType, p>::Map(sums[c]);
}
}
in_data += in_data_offset;
Expand All @@ -541,6 +545,7 @@ inline void pool_sum_3d_ncdhw_cpu(const DType *in_data, const TShape &ishape, co
const TShape &kernel, const TShape &pad, const TShape &stride,
DType *out_data,
const bool get_avg = false, const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
const int depth = ishape[2], height = ishape[3], width = ishape[4];
const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4];
const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
Expand Down Expand Up @@ -569,17 +574,17 @@ inline void pool_sum_3d_ncdhw_cpu(const DType *in_data, const TShape &ishape, co
if (get_avg && !count_include_pad) {
pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
}
DType sum = 0;
AccType sum = 0;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += a_pow_p<DType, p>::Map(in_data[(d*height+h)*width+w]) / pool_size;
sum += a_pow_p<AccType, p>::Map(in_data[(d*height+h)*width+w]) / pool_size;
}
}
}
out_data[(pd*pooled_height+ph)*pooled_width+pw] = (pool_size == 0) ?
DType(nanf("")) :
a_root_p<DType, p>::Map(sum);
AccType(nanf("")) :
a_root_p<AccType, p>::Map(sum);
}
}
}
Expand All @@ -598,6 +603,7 @@ inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, co
const TShape& kernel, const TShape& pad, const TShape& stride,
DType* out_data,
const bool get_avg = false, const bool count_include_pad = true) {
using AccType = typename PoolingTypes<DType>::AccType;
const int depth = ishape[1], height = ishape[2], width = ishape[3];
const int pooled_depth = oshape[1], pooled_height = oshape[2], pooled_width = oshape[3];
const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2];
Expand All @@ -606,7 +612,7 @@ inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, co
const int features = oshape[4];
const index_t in_data_offset = ishape[1] * ishape[2] * ishape[3] * features;
const index_t out_data_offset = oshape[1] * oshape[2] * oshape[3] * features;
std::vector<DType> sums(features);
std::vector<AccType> sums(features);
for (index_t n = 0; n < oshape[0]; ++n) {
for (int pd = 0; pd < pooled_depth; ++pd) {
for (int ph = 0; ph < pooled_height; ++ph) {
Expand Down Expand Up @@ -634,15 +640,15 @@ inline void pool_sum_3d_ndhwc_cpu(const DType* in_data, const TShape& ishape, co
for (int w = wstart; w < wend; ++w) {
const int in_index = (d * height + h) * width + w;
for (index_t c = 0; c < features; ++c) {
sums[c] += a_pow_p<DType, p>::Map(in_data[in_index * features + c]) / pool_size;
sums[c] += a_pow_p<AccType, p>::Map(in_data[in_index * features + c]) / pool_size;
}
}
}
}
for (index_t c = 0; c < features; ++c)
out_data[pool_index * features + c] = (pool_size == 0) ?
DType(nanf("")) :
a_root_p<DType, p>::Map(sums[c]);
AccType(nanf("")) :
a_root_p<AccType, p>::Map(sums[c]);
}
}
}
Expand Down
20 changes: 16 additions & 4 deletions src/operator/nn/pool_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@
namespace mxnet {
namespace op {

// Define an accumulator type AccType to permit float16-I/O lp pooling to avoid underflow.
template<typename DType>
struct PoolingTypes {
typedef DType AccType;
};

template<>
struct PoolingTypes<mshadow::half::half_t> {
typedef float AccType;
};

template<typename DType, int p>
struct a_pow_p {
static MSHADOW_XINLINE DType Map(const DType a) {
Expand Down Expand Up @@ -98,16 +109,17 @@ struct lp_grad<DType, 1> {
template<typename DType>
struct lp_grad<DType, 2> {
static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
// Avoid nan result if both grad and out_data are 0.
return (grad == DType(0.0)) ? DType(0.0) : grad * in_data / out_data;
// Avoid inf, if out_data has underflowed to 0 for a non-zero input, or nan if grad is also 0.
return (out_data == DType(0.0)) ? DType(0.0) : grad * (in_data / out_data);
}
};

template<typename DType>
struct lp_grad<DType, 3> {
static MSHADOW_XINLINE DType Map(const DType grad, const DType in_data, const DType out_data) {
// Avoid nan result if both grad and out_data are 0.
return (grad == DType(0.0)) ? DType(0.0) : grad * in_data * in_data / (out_data * out_data);
// Avoid inf, if out_data has underflowed to 0 for a non-zero input, or nan if grad is also 0.
DType in_out_ratio = in_data / out_data;
return (out_data == DType(0.0)) ? DType(0.0) : grad * in_out_ratio * in_out_ratio;
}
};

Expand Down
23 changes: 18 additions & 5 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,8 @@ def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, str

check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol)

def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_include_pad=True):
def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_include_pad=True,
tol=None):
if dim == '1D':
data = (3, 3, 10)
kernel = (4,)
Expand Down Expand Up @@ -966,7 +967,7 @@ def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_inclu
data=data, kernel=kernel, pad=pad, stride=stride,
pool_type=pool_type, pooling_convention=pooling_convention,
global_pool=False, p_value=p_value,
count_include_pad=count_include_pad, dtype=dtype)
count_include_pad=count_include_pad, tol=tol, dtype=dtype)
except:
print('pool_op_list = {}'.format(pool_op_list))
print('kernel={}, pad={}, stride={}'.format(kernel, pad, stride))
Expand All @@ -980,7 +981,7 @@ def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_inclu
test_pooling_versions_helper(pool_op_list=pool_op_list,
data=data, kernel=kernel, pad=None, stride=None,
pool_type=pool_type, global_pool=True, p_value=p_value,
count_include_pad=count_include_pad, dtype=dtype)
count_include_pad=count_include_pad, tol=tol, dtype=dtype)

# The various implementations of the standard pooling operator
std_pool_op_list = ['pool_cpu', 'pool_transposed_cpu',
Expand Down Expand Up @@ -1017,15 +1018,27 @@ def test_pooling_dim(dim, pool_type, dtype, pool_op_list, p_value=2, count_inclu
# only input values should appear in the output.
# 3. In avg pooling, the 'v1' operator divides the sum by the same window size factor,
# even at the edges, and so does not support count_include_pad = False.
# 4. The float16 'v1' pooling operator performs forward sums and averages in
# float16, whereas the std operators perform those calculations in float32, so
# greater float16 tolerances are needed when comparing across implementations.

# Double the float16 tol when comparing v1 and non-v1 implemenations, per note 4 above.
relaxed_tol = {np.dtype(np.float16): 2e-1,
np.dtype(np.float32): 1e-3,
np.dtype(np.float64): 1e-5,
np.dtype(np.uint8): 0,
np.dtype(np.int32): 0,
np.dtype(np.int64): 0}

# Exclude std implementations due to points 1 and 2 above.
test_pooling_dim('2D', 'max', dtype, v1_pool_op_list)
# The standard and 'v1' implementations match for this case.
test_pooling_dim('2D', 'avg', dtype, combo_pool_op_list, count_include_pad=True)
test_pooling_dim('2D', 'avg', dtype, combo_pool_op_list, count_include_pad=True,
tol=relaxed_tol)
# Exclude std implementations due to point 3 above.
test_pooling_dim('2D', 'avg', dtype, v1_pool_op_list, count_include_pad=False)
# The standard and 'v1' implementations match for this case.
test_pooling_dim('2D', 'sum', dtype, combo_pool_op_list)
test_pooling_dim('2D', 'sum', dtype, combo_pool_op_list, tol=relaxed_tol)

# We can compare the standard and 'v1' max pooling implementations if we eliminate padding
# (see point 2 above) and use np.float64 data so that no two random input window values are
Expand Down

0 comments on commit 098bc49

Please sign in to comment.