Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added query for cuDNN BN min. epsilon. Enabled choice of BN impl. for…
Browse files Browse the repository at this point in the history
… ONNX import (#11380)
  • Loading branch information
mkolod authored and szha committed Jul 11, 2018
1 parent 43ad56c commit d814733
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import numpy as np
from . import _translation_utils as translation_utils
from .... import symbol

# Method definitions for the callable objects mapped in the import_helper module

def identity(attrs, inputs, proto_obj):
Expand Down Expand Up @@ -209,7 +208,10 @@ def batch_norm(attrs, inputs, proto_obj):
'is_test': 'fix_gamma'})
new_attrs = translation_utils._remove_attributes(new_attrs,
['spatial', 'consumed_inputs'])
new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1})
# Disable cuDNN BN only if epsilon from model is < than minimum cuDNN eps (1e-5)
cudnn_min_eps = 1e-5
cudnn_off = 0 if attrs.get('epsilon', cudnn_min_eps) >= cudnn_min_eps else 1
new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': cudnn_off})

# in test mode "fix_gamma" should be unset.
new_attrs['fix_gamma'] = not attrs.get('is_test', 1)
Expand Down

0 comments on commit d814733

Please sign in to comment.