Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix sample vs. pop variance issue with test_numpy_op.py::test_npx_bat…
Browse files Browse the repository at this point in the history
…ch_norm
  • Loading branch information
DickJC123 committed Jul 10, 2020
1 parent 6f599a3 commit e0a7dda
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,15 +1848,18 @@ def _test_batchnorm_impl(axis,

running_mean = running_mean * momentum + \
data_mean_flat * (1 - momentum)

m = _np.prod(shape) / shape[axis]
# cudnn uses m-1 in the denominator of its sample variance calculation, not m
sample_var_adjust = 1.0 if cudnn_off or fix_gamma else m / (m-1)
running_var = running_var * momentum + \
data_var_flat * (1 - momentum)
data_var_flat * sample_var_adjust * (1 - momentum)

W = bn_gamma.reshape(expand_shape)
dnx = ograd * W
xsm = data - data_mean
nd = 1.0 / np.sqrt(data_var + epsilon)
nx = xsm * nd
m = _np.prod(shape) / shape[axis]
dvar = (dnx * xsm).sum(axis=reduce_axis, keepdims=True,
) * (-0.5) * np.power(nd, 3)
dmean = -nd * dnx.sum(axis=reduce_axis, keepdims=True) - \
Expand Down

0 comments on commit e0a7dda

Please sign in to comment.