From d814733f16d50b2dd51b6f3e7e3256b5e66c8026 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Tue, 10 Jul 2018 22:36:32 -0700 Subject: [PATCH] Added query for cuDNN BN min. epsilon. Enabled choice of BN impl. for ONNX import (#11380) --- python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 2b98aa08febf..61f342a9ae94 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -21,7 +21,6 @@ import numpy as np from . import _translation_utils as translation_utils from .... import symbol - # Method definitions for the callable objects mapped in the import_helper module def identity(attrs, inputs, proto_obj): @@ -209,7 +208,10 @@ def batch_norm(attrs, inputs, proto_obj): 'is_test': 'fix_gamma'}) new_attrs = translation_utils._remove_attributes(new_attrs, ['spatial', 'consumed_inputs']) - new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1}) + # Disable cuDNN BN only if epsilon from model is < than minimum cuDNN eps (1e-5) + cudnn_min_eps = 1e-5 + cudnn_off = 0 if attrs.get('epsilon', cudnn_min_eps) >= cudnn_min_eps else 1 + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': cudnn_off}) # in test mode "fix_gamma" should be unset. new_attrs['fix_gamma'] = not attrs.get('is_test', 1)