
Dangling outputs and dtype != float32: Gradient computation fails #8799

Closed
mseeger opened this issue Nov 23, 2017 · 5 comments

@mseeger
Contributor

mseeger commented Nov 23, 2017

Hello,

Please see the complete example below for full details.
The problem: if, in the course of a computation, certain outputs of an operator are not used downstream (I call these "dangling outputs"), then gradient computation fails if and only if dtype != float32.

I suppose this is because, somewhere deep inside, constant-zero output gradients are created for these dangling outputs, but in float32 instead of the dtype of the output. This should be really easy to fix for somebody who knows where this happens (I don't!).
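
If that guess is right, the culprit would be a pattern like the following (a hypothetical illustration, not the actual MXNet source; the variable names are mine):

import mxnet as mx

out = mx.nd.ones((2, 3), dtype='float64')  # stands in for a dangling output

# Suspected buggy pattern: mx.nd.zeros defaults to float32, which
# mismatches the float64 output it is supposed to be a gradient for
head_grad = mx.nd.zeros(out.shape)

# What it presumably should do instead: match the dtype of the output
head_grad = mx.nd.zeros(out.shape, dtype=out.dtype)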

You could argue that one should never produce dangling outputs. But for some operators this is bound to happen, and in any case things should not work for float32 and then fail for other dtypes, leaving the user with a totally obscure error message (see below; the type flags 0 and 1 in the message denote float32 and float64).

The way around this problem is to make sure there are no dangling outputs, as sketched below and shown in the commented-out lines of the example. I'd argue that this is highly obscure, and users should not have to do it.
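
Spelled out, the workaround looks like this (a sketch; I sum q before scaling so that the shapes of the two terms stay compatible):

import mxnet as mx

def crit_func_workaround(a):
    q, l = mx.nd.linalg.gelqf(a)
    # Consume the otherwise-dangling output q with zero weight. BlockGrad
    # ensures backward sees a proper (zero) gradient for q in the correct
    # dtype, instead of having to invent one.
    bogus = mx.nd.sum(mx.nd.BlockGrad(q)) * 0.0
    return mx.nd.sum(l) + bogus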

Here is the complete example. Note that the problem is not specific to linalg.gelqf; I am just using it as an example of an operator with more than one output. Any other op with more than one output gives the same error (as does an op with a single output that you then don't use, even though that is arguably pretty dumb).

--- Code ---

import mxnet as mx
import numpy as np
from mxnet import autograd

def crit_func(a):
    q, l = mx.nd.linalg.gelqf(a)
    # Could as well write: _, l = mx.nd.linalg.gelqf(a)
    # The following circumvents the issue, but is highly obscure:
    # bogus = mx.nd.BlockGrad(q) * 0.0
    # return mx.nd.sum(l) + bogus
    return mx.nd.sum(l)

ctx = mx.cpu()

dtype = np.float32
a32 = mx.nd.random.normal(shape=(2, 3), ctx=ctx, dtype=dtype)
a32.attach_grad()
with autograd.record():
    crit32 = crit_func(a32)
print 'crit32 = %f' % crit32.asscalar()
crit32.backward()
print 'grad32\n', a32.grad.asnumpy()

dtype = np.float64
a64 = mx.nd.Cast(a32, dtype=dtype)
a64.attach_grad()
with autograd.record():
    crit64 = crit_func(a64)
print 'crit64 = %f' % crit64.asscalar()
crit64.backward()
print 'grad64\n', a64.grad.asnumpy()

--- Output ---

crit32 = -0.633861
grad32
[[-0.55487823 0.57657677 0.63104177]
[-1.29575324 0.44739273 0.3476553 ]]
crit64 = -0.633862


MXNetError Traceback (most recent call last)
<ipython-input> in <module>()
24 crit64 = crit_func(a64)
25 print 'crit64 = %f' % crit64.asscalar()
---> 26 crit64.backward()
27 print 'grad64\n', a64.grad.asnumpy()

/Users/matthis/MLBerlinHomes/matthis/virtenvs/linalg_ops/linalg_ops_env/lib/python2.7/site-packages/mxnet-0.12.1-py2.7.egg/mxnet/ndarray/ndarray.pyc in backward(self, out_grad, retain_graph, train_mode)
1763 ctypes.c_int(train_mode),
1764 ctypes.c_void_p(0),
-> 1765 ctypes.c_void_p(0)))
1766
1767 def tostype(self, stype):

/Users/matthis/MLBerlinHomes/matthis/virtenvs/linalg_ops/linalg_ops_env/lib/python2.7/site-packages/mxnet-0.12.1-py2.7.egg/mxnet/base.pyc in check_call(ret)
144 """
145 if ret != 0:
--> 146 raise MXNetError(py_str(_LIB.MXGetLastError()))
147
148 if sys.version_info[0] < 3:

MXNetError: [14:09:12] include/mxnet/./tensor_blob.h:216: Check failed: mshadow::DataType::kFlag == type_flag_ TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 1

Stack trace returned 10 entries:
[bt] (0) 0 libmxnet.so 0x000000010c8d8378 _ZN4dmlc15LogMessageFatalD2Ev + 40
[bt] (1) 1 libmxnet.so 0x000000010c9046d3 _ZNK5mxnet5TBlob4dptrIdEEPT_v + 211
[bt] (2) 2 libmxnet.so 0x000000010c904041 _ZNK5mxnet5TBlob14get_with_shapeIN7mshadow3cpuELi3EdEENS2_6TensorIT_XT0_ET1_EERKNS2_5ShapeIXT0_EEEPNS2_6StreamIS5_EE + 817
[bt] (3) 3 libmxnet.so 0x000000010d719ae7 _ZNK5mxnet5TBlob8FlatToKDIN7mshadow3cpuELi3EdEENS2_6TensorIT_XT0_ET1_EEPNS2_6StreamIS5_EE + 839
[bt] (4) 4 libmxnet.so 0x000000010d713248 _ZN5mxnet2op12LaOpBackwardIN7mshadow3cpuELi2ELi2ELi4ELi1ENS0_14gelqf_backwardEEEvRKN4nnvm9NodeAttrsERKNS_9OpContextERKNSt3__16vectorINS_5TBlobENSC_9allocatorISE_EEEERKNSD_INS_9OpReqTypeENSF_ISK_EEEESJ_ + 1464
[bt] (5) 5 libmxnet.so 0x000000010d8aed74 _ZZN5mxnet10imperative12PushFComputeERKNSt3__18functionIFvRKN4nnvm9NodeAttrsERKNS_9OpContextERKNS1_6vectorINS_5TBlobENS1_9allocatorISB_EEEERKNSA_INS_9OpReqTypeENSC_ISH_EEEESG_EEEPKNS3_2OpES6_RKNS_7ContextERKNSA_IPNS_6engine3VarENSC_ISY_EEEES12_RKNSA_INS_8ResourceENSC_IS13_EEEERKNSA_IPNS_7NDArrayENSC_IS19_EEEES1D_RKNSA_IjNSC_IjEEEESL_ENKUlNS_10RunContextENSW_18CallbackOnCompleteEE_clES1I_S1J_ + 532
[bt] (6) 6 libmxnet.so 0x000000010d8ae0f7 _ZNSt3__110__function6__funcIZN5mxnet10imperative12PushFComputeERKNS_8functionIFvRKN4nnvm9NodeAttrsERKNS2_9OpContextERKNS_6vectorINS2_5TBlobENS_9allocatorISD_EEEERKNSC_INS2_9OpReqTypeENSE_ISJ_EEEESI_EEEPKNS5_2OpES8_RKNS2_7ContextERKNSC_IPNS2_6engine3VarENSE_IS10_EEEES14_RKNSC_INS2_8ResourceENSE_IS15_EEEERKNSC_IPNS2_7NDArrayENSE_IS1B_EEEES1F_RKNSC_IjNSE_IjEEEESN_EUlNS2_10RunContextENSY_18CallbackOnCompleteEE_NSE_IS1M_EEFvS1K_S1L_EEclEOS1K_OS1L_ + 55
[bt] (7) 7 libmxnet.so 0x000000010d839cf2 _ZN5mxnet6engine11NaiveEngine9PushAsyncENSt3__18functionIFvNS_10RunContextENS0_18CallbackOnCompleteEEEENS_7ContextERKNS2_6vectorIPNS0_3VarENS2_9allocatorISB_EEEESG_NS_10FnPropertyEiPKc + 194
[bt] (8) 8 libmxnet.so 0x000000010d89d3c2 _ZN5mxnet10imperative12PushFComputeERKNSt3__18functionIFvRKN4nnvm9NodeAttrsERKNS_9OpContextERKNS1_6vectorINS_5TBlobENS1_9allocatorISB_EEEERKNSA_INS_9OpReqTypeENSC_ISH_EEEESG_EEEPKNS3_2OpES6_RKNS_7ContextERKNSA_IPNS_6engine3VarENSC_ISY_EEEES12_RKNSA_INS_8ResourceENSC_IS13_EEEERKNSA_IPNS_7NDArrayENSC_IS19_EEEES1D_RKNSA_IjNSC_IjEEEESL_ + 1186
[bt] (9) 9 libmxnet.so 0x000000010d89b501 _ZN5mxnet10Imperative8InvokeOpERKNS_7ContextERKN4nnvm9NodeAttrsERKNSt3__16vectorIPNS_7NDArrayENS8_9allocatorISB_EEEESG_RKNS9_INS_9OpReqTypeENSC_ISH_EEEENS_12DispatchModeENS_10OpStatePtrE + 1137

@zhiz21

zhiz21 commented Mar 10, 2018

proposed labels: "Python", "Bug", "Operator", "Question"

@reminisce
Contributor

Minimal reproducible script:

import mxnet as mx
from mxnet import autograd


data = mx.nd.arange(16, dtype='float64').reshape((4, 4))
data.attach_grad()

with autograd.record():
    y = mx.nd.split(data, axis=0, num_outputs=2)
y[0].backward()
print(data.grad)
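
Once fixed, backward on y[0] (with the default head gradient of ones) should write ones into the rows that flow into y[0] and zeros into the rest, i.e. print something like:

[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
<NDArray 4x4 @cpu(0)>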

@apeforest
Contributor

I am working on this issue together with #9067.

@apeforest
Contributor

@mseeger This issue has been fixed. Please verify.
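
A minimal check, assuming a build that contains the fix (this is the float64 half of the original script, condensed):

import mxnet as mx
import numpy as np
from mxnet import autograd

# The originally failing pattern: dangling output q, dtype float64
a = mx.nd.random.normal(shape=(2, 3), dtype=np.float64)
a.attach_grad()
with autograd.record():
    q, l = mx.nd.linalg.gelqf(a)
    crit = mx.nd.sum(l)
crit.backward()
print(a.grad)  # should print a 2x3 float64 gradient instead of raising MXNetError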

@apeforest
Contributor

@sandeep-krishnamurthy This issue has been resolved. Please close it. Thanks!
