diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 69c35f85c648..69f87ae42f97 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -83,7 +83,8 @@ void DotForward_(const nnvm::NodeAttrs& attrs, (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask)) << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU"; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { + // VectorDot() with fp16 is not supported in mshadow. Dispatch to dot() instead. + if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1 && inputs[0].type_flag_ != kFloat16) { CHECK_NE(req[0], kAddTo) << "AddTo not yet supported"; Tensor out = outputs[0].get(s); VectorDot(out, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index fc003b2271ef..b0c640bc6e35 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2631,47 +2631,52 @@ def test_stn_valid_sampling(): ) + target_shape)) -# @haojin2: Getting rid of fixed seed as flakiness could not be reproduced, -# tracked at https://github.com/apache/incubator-mxnet/issues/11714 @with_seed() def test_dot(): - ctx=default_context() + ctx = default_context() dtypes = ['float32', 'float64'] + ndims = [2] if ctx.device_type == 'gpu': dtypes += ['float16'] + ndims += [1] # Test normal dot. - for data_type in dtypes: - for m in range(1, 5): - for k in range(1, 5): - for n in range(1, 5): - a_npy = np.random.normal(0, 1, (m, k)) - a_npy = a_npy.astype(data_type) - b_npy = np.random.normal(0, 1, (k, n)) - b_npy = b_npy.astype(data_type) - c_npy = np.empty((m, n), dtype=data_type) - ograd_npy = np.random.normal(0, 1, (m, n)) - ograd_npy = ograd_npy.astype(data_type) - agrad_npy = np.empty((m, k), dtype=data_type) - bgrad_npy = np.empty((k, n), dtype=data_type) - c_npy[:, :] = np.dot(a_npy[:, :], b_npy[:, :]) - bgrad_npy[:, :] = np.dot(a_npy[:, :].T, ograd_npy[:, :]) - agrad_npy[:, :] = np.dot(ograd_npy[:, :], b_npy[:, :].T) - a = mx.sym.Variable('a', dtype=data_type) - b = mx.sym.Variable('b', dtype=data_type) - c = mx.sym.dot(a, b) - exe = c.simple_bind(ctx=ctx, a=a_npy.shape, b=b_npy.shape) - outputs = exe.forward(is_train=True, a=a_npy, b=b_npy) - assert_almost_equal(outputs[0].asnumpy(), c_npy, - rtol=1e-2 if data_type == 'float16' else 1e-3, - atol=1e-2 if data_type == 'float16' else 1e-3) - exe.backward(out_grads=[mx.nd.array(ograd_npy, mx.cpu()).astype(data_type)]) - assert_almost_equal(exe.grad_dict['a'].asnumpy(), agrad_npy, - rtol=1e-2 if data_type == 'float16' else 1e-3, - atol=1e-2 if data_type == 'float16' else 1e-3) - assert_almost_equal(exe.grad_dict['b'].asnumpy(), bgrad_npy, - rtol=1e-2 if data_type == 'float16' else 1e-3, - atol=1e-2 if data_type == 'float16' else 1e-3) + for ndim in ndims: + for data_type in dtypes: + for m in range(1, 5): + for k in range(1, 5): + if ndim == 1 and k != 1: + pass + for n in range(1, 5): + a_shape = (m, k) if ndim == 2 else (m,) + b_shape = (k, n) if ndim == 2 else (n,) + a_npy = np.random.normal(0, 1, (m, k)) + a_npy = a_npy.astype(data_type) + b_npy = np.random.normal(0, 1, (k, n)) + b_npy = b_npy.astype(data_type) + c_npy = np.empty((m, n), dtype=data_type) + ograd_npy = np.random.normal(0, 1, (m, n)) + ograd_npy = ograd_npy.astype(data_type) + agrad_npy = np.empty((m, k), dtype=data_type) + bgrad_npy = np.empty((k, n), dtype=data_type) + c_npy[:, :] = np.dot(a_npy[:, :], b_npy[:, :]) + bgrad_npy[:, :] = np.dot(a_npy[:, :].T, ograd_npy[:, :]) + agrad_npy[:, :] = np.dot(ograd_npy[:, :], b_npy[:, :].T) + a = mx.sym.Variable('a', dtype=data_type) + b = mx.sym.Variable('b', dtype=data_type) + c = mx.sym.dot(a, b) + exe = c.simple_bind(ctx=ctx, a=a_npy.shape, b=b_npy.shape) + outputs = exe.forward(is_train=True, a=a_npy, b=b_npy) + assert_almost_equal(outputs[0].asnumpy(), c_npy, + rtol=1e-2 if data_type == 'float16' else 1e-3, + atol=1e-2 if data_type == 'float16' else 1e-3) + exe.backward(out_grads=[mx.nd.array(ograd_npy, mx.cpu()).astype(data_type)]) + assert_almost_equal(exe.grad_dict['a'].asnumpy(), agrad_npy, + rtol=1e-2 if data_type == 'float16' else 1e-3, + atol=1e-2 if data_type == 'float16' else 1e-3) + assert_almost_equal(exe.grad_dict['b'].asnumpy(), bgrad_npy, + rtol=1e-2 if data_type == 'float16' else 1e-3, + atol=1e-2 if data_type == 'float16' else 1e-3) # Test dot with transpose flag using gradient checker. def dot_sym(data_type):