diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index eb33f9b5217e..935bd9ab1823 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -435,6 +435,90 @@ def test_nag():
                 continue
             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
 
+# SGLD
+class PySGLD(mx.optimizer.Optimizer):
+    """Python reference implementation of SGLD."""
+
+    def __init__(self, **kwargs):
+        super(PySGLD, self).__init__(**kwargs)
+
+    def create_state(self, index, weight):
+        return None
+
+    def update(self, index, weight, grad, state):
+        assert isinstance(weight, mx.nd.NDArray)
+        assert isinstance(grad, mx.nd.NDArray)
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+
+        grad = grad * self.rescale_grad
+        if self.clip_gradient is not None:
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+        weight[:] += -lr/2 * (grad + wd * weight) + mx.random.normal(0, math.sqrt(lr), shape=weight.shape,
+                                                                     dtype=weight.dtype, ctx=weight.context)
+
+
+@with_seed()
+def test_sgld():
+    opt1 = PySGLD
+    opt2 = mx.optimizer.SGLD
+    shape = (3, 4, 5)
+    ns_options = [1234, 42]
+
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
+
+
+    def compare_optimizer_noise_seeded(opt1, opt2, shape, dtype, noise_seed,
+                                       w_stype='default', g_stype='default',
+                                       rtol=1e-4, atol=1e-5, compare_states=True):
+        """Compare opt1 and opt2, seeding the RNG before each update so that
+        both optimizers draw the same Gaussian noise in the SGLD step.
+
+        """
+        if w_stype == 'default':
+            w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+            w1 = w2.copyto(default_context())
+        elif w_stype == 'row_sparse' or w_stype == 'csr':
+            w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
+            w1 = w2.copyto(default_context()).tostype('default')
+        else:
+            raise Exception("type not supported yet")
+        if g_stype == 'default':
+            g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+            g1 = g2.copyto(default_context())
+        elif g_stype == 'row_sparse' or g_stype == 'csr':
+            g2 = rand_ndarray(shape, g_stype, dtype=dtype)
+            g1 = g2.copyto(default_context()).tostype('default')
+        else:
+            raise Exception("type not supported yet")
+
+        state1 = opt1.create_state_multi_precision(0, w1)
+        state2 = opt2.create_state_multi_precision(0, w2)
+        if compare_states:
+            compare_ndarray_tuple(state1, state2)
+
+        # set seed for Gaussian noise replication
+        mx.random.seed(noise_seed)
+        opt1.update_multi_precision(0, w1, g1, state1)
+        mx.random.seed(noise_seed)
+        opt2.update_multi_precision(0, w2, g2, state2)
+        if compare_states:
+            compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
+        assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
+
+    for seed in ns_options:
+        for dtype in [np.float16, np.float32, np.float64]:
+            for params in itertools.product(cg_options, wd_options, mp_options):
+                kwarg = {k: v for param in params for k, v in param.items()}
+                if (dtype == np.float16 and ('multi_precision' not in kwarg or
+                                             not kwarg['multi_precision'])):
+                    continue
+                compare_optimizer_noise_seeded(opt1(**kwarg), opt2(**kwarg), shape, dtype, seed)
+
+
 # FTML
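Note on what this test relies on: both PySGLD above and the built-in mx.optimizer.SGLD apply the Langevin update w <- w - (lr/2) * (grad + wd * w) + N(0, lr), so the two implementations can only agree elementwise if they draw identical Gaussian noise; that is why compare_optimizer_noise_seeded calls mx.random.seed(noise_seed) immediately before each update. As a rough illustration of the update rule itself (a minimal sketch, not part of the patch; sgld_step and its arguments are hypothetical names), a single step in plain NumPy:

    import math
    import numpy as np

    def sgld_step(weight, grad, lr, wd, rng):
        # Half-step of gradient descent on the L2-regularized loss, plus
        # Gaussian noise whose variance equals the learning rate.
        noise = rng.normal(loc=0.0, scale=math.sqrt(lr), size=weight.shape)
        return weight - lr / 2 * (grad + wd * weight) + noise

    rng = np.random.default_rng(1234)   # fixed seed, mirroring noise_seed above
    w = rng.standard_normal((3, 4, 5))
    g = rng.standard_normal((3, 4, 5))
    w = sgld_step(w, g, lr=0.01, wd=0.03, rng=rng)

Two runs seeded the same way produce identical weights, which is exactly the replication property the seeded comparison in the test depends on.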