[v1.x] Backport Unittest tolerance handling improvements (apache#18694). Also test seeding (apache#18762). (apache#19148)

* Add sm arch 80 to Makefile

* Unittest tolerance handling improvements (apache#18694)

* Add sm arch 80 to Makefile

* Add TF32 to cuBLAS GEMMs

Signed-off-by: Serge Panev <[email protected]>

* Add CUDA version guards

Signed-off-by: Serge Panev <[email protected]>

* Remove useless TF32 for double and old CUDA version

Signed-off-by: Serge Panev <[email protected]>

* Factorize VERSION_ADJUSTED_TF32_MATH

Signed-off-by: Serge Panev <[email protected]>

* Add TF32 considerations to test_util.py:check_consistency()

* Bypass test_gluon_gpu.py:test_large_models if gmem >32GB

* Default tols in assert_almost_equal() now a function of dtype and ctx

* Expand types listed by default_tols()

* Fix pylint

* All with_seed() tests to waitall in teardown

* Elevate MXNET_TEST_SEED logging to WARNING

* Revert test_gluon_gpu.py:test_rnn_layer to default tols

* Fix test_gluon_model_zoo_gpu.py::test_inference and test_operator_gpu.py::test_np_linalg_{solve,tensorinv}

* test_numpy_interoperability.py to not fix seed for rest of CI

* Further fix to test_np_linalg_tensorinv

* Fix test_gluon_data.py:test_dataloader_context when run on 1-GPU system.

* Fix test_operator_gpu.py::test_embedding_with_type

* Fix test_operator_gpu.py::{test_*convolution_large_c,test_np_linalg_tensorsolve}

* Remove unneeded print() from test_numpy_interoperability.py

* Unify tol handling of check_consistency() and assert_almost_equal().  Test tweaks.

* Add tol handling of assert_almost_equal() with number args

* Add tol handling of bool comparisons

* Fix test_numpy_op.py::test_np_random_rayleigh

* Fix test_operator_gpu.py::test_batchnorm_with_type

* Fix test_gluon.py::test_sync_batchnorm in cpu selftest

* Improve unittest failure reporting

* Add to robustness of test_operator_gpu.py::test_embedding_with_type

* Check_consistency() to use equal backward gradients for increased test robustness

* Fix test_operator_gpu.py::test_{fully_connected,gemm}.  Add default_numeric_eps().

* test_utils.py fix for numeric gradient calc

* Reinstate rtol=1e-2 for test_operator.py::test_order

* Remove auto-cast of check_consistency() input data to least precise dtype (not needed)

* Fix test_operator.py::test_{reciprocal,cbrt,rcbrt}_op

* Expand default float64 numeric_eps for test_operator_gpu.py::test_softmin

* Fix segfault-on-error of @Retry decorator. Add test isolation.

* assert_almost_equal() to handle a,b scalars

* Fix test_operator_gpu.py::test_gluon_{mvn,mvn_v1} race

* Fix test_operator_gpu.py::test_flatten_slice_after_conv via scale

* Remove test_utils.py:almost_equal_ignore_nan()

* Fix sample vs. pop variance issue with test_numpy_op.py::test_npx_batch_norm

* Expose test_utils.py:effective_dtype() and use to fix test_operator_gpu.py::test_np_linalg_svd

* Fix true_divide int_array / int_scalar -> float_array to honor np_default_dtype

* Try test_elemwise_binary_ops serial to avoid pytest worker crash

* Fix (log_)softmax backward on empty ndarray

* Temporarily log all CI seeds to troubleshoot seed non-determinism

* Revert "Temporarily log all CI seeds to troubleshoot seed non-determinism"

This reverts commit f60eff2.

* Temp log all CI seeds to troubleshoot unwanted seed determinism

* Revert "Add sm arch 80 to Makefile"

This reverts commit f9306ce.

* Same fix of sample vs. pop variance issue, now with test_operator_gpu.py::test_batchnorm

* Revert "Temp log all CI seeds to troubleshoot unwanted seed determinism"

This reverts commit ff328ef.

* Marking test_sparse_dot_grad with garbage_expected after teardown error

* Fix flakiness of test_gluon_probability{_v1,_v2}.py::test_gluon_kl{_v1,}

* Temp skip of test_aggregate_duplication on gpu

* Add seeding to test_{numpy,}_contrib_gluon_data_vision.py.  Make created files unique.

* Add ndarray module isolation to help debug test_bbox_augmenters worker crash

* Marking test_sparse_square_sum serial after pytest worker crash

* Fix flakiness of test_gluon_probability{_v1,_v2}.py::test_half_cauchy{_v1,}

Co-authored-by: Serge Panev <[email protected]>
Co-authored-by: Bart Gawrych <[email protected]>

* Fix test_gluon_data.py:test_dataloader_context when run on 1-GPU system.

* Remove pytest decorators introduced in error

* Fix test_forward.py:test_consistency

* Fix test_numpy_op.py tests

* Improve test seeding in test_numpy_interoperability.py (apache#18762)

* Fix test_numpy_op.py:test_np_random_{beta,chisquare}

* Reduce problem sizes with test_optimizer.py:test_multilamb

* Skip test_gluon_gpu.py:test_fused_{lstm,gru}_layer, fix test_rnn_cells, for fp16 contexts

* Trigger CI

Co-authored-by: Serge Panev <[email protected]>
Co-authored-by: Bart Gawrych <[email protected]>
3 people committed Sep 17, 2020
1 parent 620d058 commit ce0a518
Showing 18 changed files with 489 additions and 322 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -448,7 +448,7 @@ endif
# be JIT-compiled by the updated driver from the included PTX.
ifeq ($(USE_CUDA), 1)
ifeq ($(CUDA_ARCH),)
-KNOWN_CUDA_ARCHS := 30 35 50 52 60 61 70 75
+KNOWN_CUDA_ARCHS := 30 35 50 52 60 61 70 75 80
# Run nvcc on a zero-length file to check architecture-level support.
# Create args to include SASS in the fat binary for supported levels.
CUDA_ARCH := $(foreach arch,$(KNOWN_CUDA_ARCHS), \
360 changes: 232 additions & 128 deletions python/mxnet/test_utils.py

Large diffs are not rendered by default.
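Since the test_utils.py diff is collapsed here, the following is a rough sketch of the idea named in the commit message: assert_almost_equal() now derives default rtol/atol from the dtypes of the arrays being compared (and from the context). The helper name, tolerance values, and selection rule below are illustrative assumptions, not the actual test_utils.py code:

import numpy as np

# Hypothetical sketch -- names and values are assumptions, not the real implementation.
_DEFAULT_TOLS = {
    np.dtype(np.float16): (1e-2, 1e-4),   # (rtol, atol)
    np.dtype(np.float32): (1e-4, 1e-6),
    np.dtype(np.float64): (1e-7, 1e-9),
}

def default_tols(a, b):
    """Pick rtol/atol from the least precise floating dtype of the two inputs."""
    float_dtypes = [np.asarray(x).dtype for x in (a, b)
                    if np.asarray(x).dtype in _DEFAULT_TOLS]
    if not float_dtypes:
        return 0.0, 0.0          # integer/bool comparisons stay exact
    loosest = max(float_dtypes, key=lambda d: _DEFAULT_TOLS[d][0])
    return _DEFAULT_TOLS[loosest]

The commit message's TF32 entries suggest that on GPUs where GEMMs run in TF32, float32 results are held to looser, roughly float16-like tolerances; see the cuBLAS changes below.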

8 changes: 8 additions & 0 deletions src/operator/linalg.h
@@ -280,6 +280,14 @@ void linalg_batch_det_backward_helper(const Tensor<xpu, 3, DType>& LU,
const DType zero_det,
const mxnet::OpContext& ctx);

+#ifdef __CUDACC__
+#if CUDA_VERSION < 11000
+#define VERSION_ADJUSTED_TF32_MATH CUBLAS_DEFAULT_MATH
+#else
+#define VERSION_ADJUSTED_TF32_MATH CUBLAS_TF32_TENSOR_OP_MATH
+#endif
+#endif // __CUDACC__

#include "linalg_impl.h"

#endif // MXNET_OPERATOR_LINALG_H_
34 changes: 26 additions & 8 deletions src/operator/linalg_impl.h
@@ -205,12 +205,15 @@ inline void linalg_gemm<gpu, float>(const Tensor<gpu, 2, float>& A,
#else
cublasDataType_t full_datatype = CUBLAS_DATA_FULL;
#endif
+auto handle = Stream<gpu>::GetBlasHandle(s);
+cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH);
CUBLAS_CALL(cublasSgemmEx(
-Stream<gpu>::GetBlasHandle(s), (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(1), C.size(0),
(tB ? B.size(1) : B.size(0)), &alpha, B.dptr_, full_datatype, B.stride_,
A.dptr_, full_datatype, A.stride_, &beta, C.dptr_, full_datatype,
-C.stride_))
+C.stride_));
+CUBLAS_CALL(cublasSetMathMode(handle, saved_math_mode));
}

#else
@@ -228,13 +231,16 @@ void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<g
using mshadow::gpu; \
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
-CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+auto handle = Stream<gpu>::GetBlasHandle(s); \
+cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH); \
+CUBLAS_CALL(cublas##fname(handle, \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
B.dptr_, B.size(1)*B.stride_, B.stride_, \
A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
+CUBLAS_CALL(cublasSetMathMode(handle, saved_math_mode)); \
}
LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
@@ -342,13 +348,22 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \
using namespace mshadow::cuda; \
-CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+auto handle = Stream<gpu>::GetBlasHandle(s); \
+cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH); \
+CUBLAS_CALL(cublas##fname(handle, \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(2), C.size(1), (tB ? B.size(2) : B.size(1)), \
-&alpha, B.dptr_, B.stride_, B.size(1) * B.stride_, \
-A.dptr_, A.stride_, A.size(1) * A.stride_, \
-&beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0))) \
+&alpha, \
+B.dptr_, B.stride_, \
+static_cast<int64_t>(B.size(1) * B.stride_), \
+A.dptr_, A.stride_, \
+static_cast<int64_t>(A.size(1) * A.stride_), \
+&beta, \
+C.dptr_, C.stride_, \
+static_cast<int64_t>(C.size(1) * C.stride_), \
+A.size(0))) \
+CUBLAS_CALL(cublasSetMathMode(handle, saved_math_mode)); \
}

LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
@@ -373,7 +388,7 @@

using namespace mshadow::cuda;
auto cublas_math_mode =
-use_tensor_ops ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
+use_tensor_ops ? CUBLAS_TENSOR_OP_MATH : VERSION_ADJUSTED_TF32_MATH;
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);

// cublasGemmStridedBatchedEx is only supported for GPU with architecture
@@ -414,6 +429,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
CHECK_NOTNULL(s); \
linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
+auto handle = Stream<gpu>::GetBlasHandle(s); \
+cublasMath_t saved_math_mode = SetCublasMathMode(handle, VERSION_ADJUSTED_TF32_MATH); \
for (index_t i = 0; i < A.size(2); ++i) { \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
Expand All @@ -423,6 +440,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
}\
+SetCublasMathMode(handle, saved_math_mode); \
}

LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
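A note on why these cuBLAS TF32 changes feed back into the tolerance work above: TF32 keeps float32's exponent range but only about 10 explicit mantissa bits, so float32 GEMMs routed through TF32 tensor cores carry roughly half-precision relative error. A quick unit-roundoff comparison (back-of-the-envelope arithmetic, not part of the diff):

# Rough unit roundoff per format, from the number of explicit mantissa bits.
formats = {'float32': 23, 'tf32': 10, 'float16': 10}
for name, bits in formats.items():
    print('{:8s} eps ~ 2**-{} = {:.1e}'.format(name, bits, 2.0 ** -bits))
# float32  eps ~ 2**-23 = 1.2e-07
# tf32     eps ~ 2**-10 = 9.8e-04
# float16  eps ~ 2**-10 = 9.8e-04

This is why the tolerance handling in test_utils.py has to account for TF32-enabled contexts when comparing float32 GPU results.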
19 changes: 11 additions & 8 deletions src/operator/numpy/np_true_divide-inl.h
@@ -58,14 +58,17 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs,
});
} else {
#ifndef _WIN32
-CHECK_EQ(outputs[0].type_flag_, kFloat32) << "true_divide only supports float32 output "
-"when input's dtype is "
-<< type_string(inputs[0].type_flag_);
-MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-Kernel<op_with_req<OP, Req>, xpu>::Launch(
-s, data.Size(), out.dptr<float>(), data.dptr<DType>(),
-static_cast<float>(alpha));
+CHECK(out.type_flag_ == mshadow::kFloat32 || out.type_flag_ == mshadow::kFloat64)
+<< "true_divide only supports float32 and float64"
+" output when input's dtype is "
+<< type_string(inputs[0].type_flag_);
+MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, ODType, {
+MXNET_INT_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+Kernel<op_with_req<OP, Req>, xpu>::Launch(
+s, data.Size(), out.dptr<ODType>(), data.dptr<DType>(),
+static_cast<ODType>(alpha));
+});
});
});
#else
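For context on the user-visible behavior this enables (per the "Fix true_divide int_array / int_scalar -> float_array to honor np_default_dtype" entry above): dividing an integer array by an integer scalar under NumPy semantics now also works when the promoted output dtype is float64 rather than float32. A minimal sketch, assuming a build of this branch; the mechanism for switching the NumPy default dtype to 64-bit is not shown here:

import mxnet as mx
from mxnet import npx

npx.set_np()                                # enable NumPy semantics
a = mx.np.array([1, 2, 3], dtype='int32')
out = a / 2                                 # true_divide: int array / int scalar
print(out.dtype)                            # float32 by default; float64 when the NumPy
                                            # default dtype is 64-bit, the case this change adds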
2 changes: 1 addition & 1 deletion tests/python/gpu/test_forward.py
@@ -74,7 +74,7 @@ def test_consistency(dump=False):
ctx_list = [{'ctx': mx.gpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}},
{'ctx': mx.cpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}}]
gt = check_consistency(sym, ctx_list, arg_params=arg_params, aux_params=aux_params,
-tol=1e-3, grad_req='null', raise_on_err=False, ground_truth=gt)
+rtol=1e-3, atol=1e-3, grad_req='null', raise_on_err=False, ground_truth=gt)
if dump:
np.savez('data/inception-v3-dump.npz', **{n: a.asnumpy() for n, a in gt.items()})

16 changes: 11 additions & 5 deletions tests/python/gpu/test_gluon_gpu.py
@@ -50,10 +50,9 @@ def check_rnn_layer(layer):
states = layer.begin_state(16)
co, cs = layer(x, states)

-# atol of 1e-6 required, as exposed by seed 2124685726
-assert_almost_equal(go, co, rtol=1e-2, atol=1e-6)
+assert_almost_equal(go, co)
for g, c in zip(gs, cs):
-assert_almost_equal(g, c, rtol=1e-2, atol=1e-6)
+assert_almost_equal(g, c)


@with_seed()
@@ -70,9 +69,9 @@ def check_rnn_layer_w_rand_inputs(layer):
states = layer.begin_state(16)
co, cs = layer(x, states)

-assert_almost_equal(go, co, rtol=1e-2, atol=1e-6)
+assert_almost_equal(go, co)
for g, c in zip(gs, cs):
-assert_almost_equal(g, c, rtol=1e-2, atol=1e-6)
+assert_almost_equal(g, c)


@with_seed()
@@ -481,6 +480,13 @@ def tensor_size(big_tensor_bytes):
# This in the past has given cudnnFind() trouble when it needed to allocate similar I/O's
# from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by default 5%).
(free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id)
+# This test needs to be 'qualified' for use with each new larger memory size
+largest_supported_total_mem_GB = 32
+if (total_mem_bytes > largest_supported_total_mem_GB * 1024 * 1024 * 1024):
+sys.stderr.write(
+' bypassing test due to too-large global memory of size {} ... '.format(total_mem_bytes))
+return

start_size = tensor_size(0.20 * total_mem_bytes)
num_trials = 10
sys.stderr.write(
2 changes: 1 addition & 1 deletion tests/python/gpu/test_gluon_model_zoo_gpu.py
@@ -91,7 +91,7 @@ def test_inference():
max_val = np.max(np.abs(cpu_out.asnumpy()))
gpu_max_val = np.max(np.abs(gpu_out.asnumpy()))
eprint(model_name + ": CPU " + str(max_val) + ", GPU " + str(gpu_max_val))
-assert_almost_equal(cpu_out / max_val, gpu_out / gpu_max_val, rtol=1e-3, atol=1e-3)
+assert_almost_equal(cpu_out / max_val, gpu_out / gpu_max_val)

def get_nn_model(name):
if "densenet" in name:
(Diffs for the remaining changed files are not rendered.)
