F.Take Backwards - Incorrect Gradient #19817
Comments
I can confirm that this bug has been fixed on the master branch. Here are the outputs from the master branch (after adopting the new Gluon interface).

Script:

```python
import numpy as np
import mxnet as mx
from mxnet.gluon.nn import HybridBlock, Conv1D, HybridSequential, HybridLambda, Dense
from mxnet import autograd, nd
from mxnet.gluon.loss import L2Loss


def print_grads(model, ctx=mx.cpu()):
    pd = model.collect_params()
    total_grad_l2 = 0
    total_grad_l1 = 0
    total_grad_linf = 0
    for p in pd:
        try:
            g = pd[p].grad(ctx) / N
            g2 = (g**2).sum().as_in_context(mx.cpu()).asscalar()
            g1 = g.abs().sum().as_in_context(mx.cpu()).asscalar()
            ginf = g.max().as_in_context(mx.cpu()).asscalar()
            total_grad_linf = max(total_grad_linf, ginf)
            total_grad_l2 += g2
            total_grad_l1 += g1
            print(f"||g_param||_2: {g2**0.5:.2E} | Param: {p}")
        except Exception:
            pass
    grad_info = f"""
    -------------------------------------------
    ------- Grad Info
     * ||g||_2: {total_grad_l2**0.5:.2E}
     * ||g||_1: {total_grad_l1:.2E}
     * ||g||_inf: {total_grad_linf:.2E}
    """
    print(grad_info)


def run_model(model, loss, X, Y, num_iters=1):
    for i in range(num_iters):
        with autograd.record():
            Y_hat = model(X)
            ll = loss(Y_hat, Y)
            ll = ll.sum()
        ll.backward()
        print_grads(model)
    return Y_hat


def conv_layer(atrous_rates, num_channels):
    convs = HybridSequential()
    convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
    for rate in atrous_rates:
        convs.add(Conv1D(num_channels, 3, padding=rate, dilation=rate, activation='tanh'))
    convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
    return convs


class Model(HybridBlock):
    """
    Model takes tensors of shape N x T x C and produces predictions with shape N x T
    """
    def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
        super().__init__()
        self.use_take = use_take
        self.convs = conv_layer(atrous_rates, conv_units)
        self.dense_out = Dense(1, flatten=False, activation='tanh')

    def hybrid_forward(self, F, X):
        X1 = X
        X2 = self.convs(X1)
        if self.use_take:
            X3 = F.take(X2, nd.array([1, 2, 3]), axis=-1)
        else:
            X3 = F.slice_axis(X2, begin=1, end=4, axis=-1)
        X4 = self.dense_out(X3)
        X4 = F.squeeze(X4, axis=-1)
        return X4


if __name__ == "__main__":
    N = 30
    T = 20
    C = 8
    conv_units = 5
    atrous_rates = [1, 2, 4]

    np.random.seed(1234)
    X = np.random.normal(size=(N, T, C))
    Y = np.random.normal(size=(N, T))
    X, Y = nd.array(X), nd.array(Y)

    # Using F.take
    mx.random.seed(12354)
    model = Model(conv_units, atrous_rates, use_take=True)
    model.initialize()
    loss = L2Loss()
    Y_hat1 = run_model(model, loss, X, Y)

    # Using F.slice_axis
    mx.random.seed(12354)
    model2 = Model(conv_units, atrous_rates, use_take=False)
    model2.initialize()
    loss2 = L2Loss()
    Y_hat2 = run_model(model2, loss2, X, Y)

    delta = nd.abs(Y_hat1 - Y_hat2).sum().asscalar()
    print("==== Same outputs?")
    print(f"Y_hat1 - Yhat2 = {delta:.4f}")
```
Thanks for looking into this -- do you know which commit fixed the bug? Also, do you know which upcoming release would contain the bugfix?
Actually, I think this bug appears to be non-deterministic. If I run the script a couple more times I get weird results (output omitted), and this happens on both v1.x and on master.
Environment: from commit bca8de8.
Update: if I turn off MKL-DNN, the results are consistently different.
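For anyone else checking the MKL-DNN angle, one way to turn it off is via the MXNET_MKLDNN_ENABLED environment variable; this is a sketch assuming the build honors that variable, and it must be set before MXNet is imported (or exported in the shell instead):

```python
import os
# Assumption: this MXNet build reads MXNET_MKLDNN_ENABLED at import time
os.environ["MXNET_MKLDNN_ENABLED"] = "0"

import mxnet as mx
# Reports whether MKL-DNN was compiled into this build
# (the env var only disables it at runtime, it does not change the build)
print(mx.runtime.Features().is_enabled("MKLDNN"))
```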
Yeah, I observe similar behavior on v1.x -- sometimes the grad calculation is correct, but most of the time the gradients differ.
I think the issue should be fixed by #20166. Shall we close the issue? @ceisenach @szha
When I use the latest nightly builds, I no longer observe the bug, so it seems resolved to me. Thanks for the fix!
Description
The backward implementation of F.take computes an incorrect gradient when used after a sequence of transpose -> convolution -> transpose. Any trainable parameters that receive gradients through the F.take operator end up with incorrect gradients. Equivalent implementations using slice operators produce correct results.
Other Details
I have been unable to find any other scenario in which it happens (for example, if one replaces the conv layers in the example below with a linear layer, there is no issue with the gradient computation).
I also encounter the bug on MXNet 1.5 and 1.6 (I have not tested earlier versions).
To Reproduce
Below I provide an example of a simple model with two implementations -- one that uses F.take (Model A) and one that uses F.slice_axis (Model B) instead. The script provided below instantiates both implementations with the same initial weights, computes the L2Loss, and prints the gradients from both models. A random seed is set so the output should be deterministic (and it is for Model B).
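As a quick illustration of why the two branches are interchangeable on the forward pass, here is a minimal sketch on a plain ndarray (shape and indices chosen only for illustration, not taken from the report):

```python
import mxnet as mx
from mxnet import nd

x = nd.random.normal(shape=(2, 4, 6))

# Gather columns 1..3 of the last axis in two ways
a = nd.take(x, nd.array([1, 2, 3]), axis=-1)   # Model A style
b = nd.slice_axis(x, axis=-1, begin=1, end=4)  # Model B style

print(nd.abs(a - b).sum().asscalar())  # expected to print 0.0
```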
Steps to reproduce
Run the provided script (python take_bug.py).
Result
The gradients in the model that uses F.take are on the order of 1e28 (or in some cases are infinite). The results are non-deterministic.
Example output from the script I provided (omitted).
It appears that there is either an OOB memory access or some values involved in the calculation are not initialized before they are used. I haven't attempted to track down the root cause.
What have you tried to solve it?
In many cases, one can work around the bug by using one of the slice operators and concatenation instead. They do not appear to have any issues.
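A minimal sketch of that kind of workaround (the indices and shapes here are only for illustration): a take over non-contiguous indices can be emulated with slice_axis plus concat.

```python
import mxnet as mx
from mxnet import nd

x = nd.random.normal(shape=(2, 5, 6))

# Emulate nd.take(x, [0, 2], axis=-1) without the take operator
workaround = nd.concat(
    nd.slice_axis(x, axis=-1, begin=0, end=1),
    nd.slice_axis(x, axis=-1, begin=2, end=3),
    dim=2,  # concatenate along the last axis of this (2, 5, 6) tensor
)

reference = nd.take(x, nd.array([0, 2]), axis=-1)
print(nd.abs(workaround - reference).sum().asscalar())  # expected 0.0
```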
Environment
OS: Ubuntu 18.04
Python: 3.8.5
pip: 20.2.3
mxnet: 1.7.0 (Commit Hash: 64f737c)
numpy: 1.19.2