diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 2b98aa08febf..61f342a9ae94 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -21,7 +21,6 @@ import numpy as np from . import _translation_utils as translation_utils from .... import symbol - # Method definitions for the callable objects mapped in the import_helper module def identity(attrs, inputs, proto_obj): @@ -209,7 +208,10 @@ def batch_norm(attrs, inputs, proto_obj): 'is_test': 'fix_gamma'}) new_attrs = translation_utils._remove_attributes(new_attrs, ['spatial', 'consumed_inputs']) - new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1}) + # Disable cuDNN BN only if epsilon from model is < than minimum cuDNN eps (1e-5) + cudnn_min_eps = 1e-5 + cudnn_off = 0 if attrs.get('epsilon', cudnn_min_eps) >= cudnn_min_eps else 1 + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': cudnn_off}) # in test mode "fix_gamma" should be unset. new_attrs['fix_gamma'] = not attrs.get('is_test', 1)