diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4dd858a51c4b..5837069a8c32 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2317,6 +2317,11 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape, mx_uint ndim, int dtype, NDArrayHandle *out); +/*! + * \brief Query cuDNN for minimum permissible epsilon for BatchNorm. If not installed, return NaN. + * \param result the minimum epsilon provided by cuDNN + */ +MXNET_DLL int MXGetCudnnBnEpsilon(float* result); #ifdef __cplusplus } diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 0fb73b3c7dda..0b0c4a8e8e35 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -680,3 +680,16 @@ def write_all_str(module_file, module_all_list): module_op_file.close() write_all_str(module_internal_file, module_internal_all) module_internal_file.close() + +def get_cudnn_epsilon(): + """Check value of CUDNN_BN_MIN_EPSILON. If the value + of epsilon in a model is less than this, cuDNN BatchNorm + should be replaced with MxNet BatchNorm since cuDNN requires + this minimum to be met. This can be helpful e.g. in ONNX import, + so that cuDNN BN is only disabled if the value in the checkpoint + is less than CUDNN_BN_MIN_EPSILON, and enabled otherwise. 
+ If cuDNN is not installed, this call will return NaN instead of + raising an exception.""" + result = ctypes.c_float(0.0) + _LIB.MXGetCudnnBnEpsilon(ctypes.byref(result)) + return result.value diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 2b98aa08febf..27553cff282e 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -18,9 +18,11 @@ # coding: utf-8 """ Module for translating ONNX operators into Mxnet operatoes""" # pylint: disable=unused-argument,protected-access +import math import numpy as np from . import _translation_utils as translation_utils from .... import symbol +from ....base import get_cudnn_epsilon # Method definitions for the callable objects mapped in the import_helper module @@ -209,7 +211,9 @@ def batch_norm(attrs, inputs, proto_obj): 'is_test': 'fix_gamma'}) new_attrs = translation_utils._remove_attributes(new_attrs, ['spatial', 'consumed_inputs']) - new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1}) + cudnn_eps = get_cudnn_epsilon() + cudnn_off = 0 if not math.isnan(cudnn_eps) and attrs.get('epsilon', 1e-5) >= cudnn_eps else 1 + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': cudnn_off}) # in test mode "fix_gamma" should be unset. 
new_attrs['fix_gamma'] = not attrs.get('is_test', 1) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 467118b9921e..3936cc302a0c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -47,6 +48,10 @@ #include "../operator/custom/custom-inl.h" #include "../operator/tensor/matrix_op-inl.h" +#if MXNET_USE_CUDA && MXNET_USE_CUDNN +#include <cudnn.h> +#endif + using namespace mxnet; // Internal function to get the information @@ -1312,3 +1317,13 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *s *out = new NDArray(shared_pid, shared_id, TShape(shape, shape + ndim), dtype); API_END(); } + +int MXGetCudnnBnEpsilon(float* result) { + API_BEGIN(); +#if MXNET_USE_CUDA && MXNET_USE_CUDNN && defined(CUDNN_BN_MIN_EPSILON) + *result = static_cast<float>(CUDNN_BN_MIN_EPSILON); +#else + *result = NAN; +#endif // MXNET_USE_CUDA && MXNET_USE_CUDNN && defined(CUDNN_BN_MIN_EPSILON) + API_END(); +} diff --git a/tests/python/gpu/test_cudnn_eps.py b/tests/python/gpu/test_cudnn_eps.py new file mode 100644 index 000000000000..5b3ea2008059 --- /dev/null +++ b/tests/python/gpu/test_cudnn_eps.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import math +import mxnet as mx +import numpy as np +import unittest + + +def test_get_cudnn_epsilon(): + eps = mx.base.get_cudnn_epsilon() + assert math.isnan(eps) or np.isclose(1e-5, mx.base.get_cudnn_epsilon(), atol=1e-10), \ + "cudnn eps is non NaN and it's not close to 1e-5" + print ("Passed") + +if __name__ == '__main__': + test_get_cudnn_epsilon()