
Ndarray.asnumpy() error with gluon dense under both GPU and CPU environment #10807

Closed · lionel92 opened this issue May 4, 2018 · 8 comments

lionel92 commented May 4, 2018

Description

I am using the Gluon API to test MKLDNN, and an error occurs when calling asnumpy() after the network's backward propagation. I ran the same test on GPU and on plain CPU, and the error occurs there as well. What's more, a small input does not trigger the error.
You can run the following minimal example to reproduce it.

import mxnet as mx
from mxnet.gluon import nn

#mx.Context.default_ctx = mx.Context('gpu', 0)
mx.Context.default_ctx = mx.Context('cpu', 0)

layer = nn.Dense(1000)
# Large input: the flattened feature dimension is 128*300*300 = 11,520,000
x = mx.nd.random.uniform(shape=(16, 128, 300, 300))
#x = mx.nd.random.uniform(shape=(1, 2, 30, 30))  # This input is ok.
x.attach_grad()
layer.collect_params().initialize()
with mx.autograd.record():
    out = layer(x)
out.backward()
print x.grad.shape
print x.grad.asnumpy().shape

Error message:
print x.grad.asnumpy().shape
  File "/home/linliu/mxnet_gpu/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 1876, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/home/linliu/mxnet_gpu/incubator-mxnet/python/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [04:25:06] include/mxnet/././tensor_blob.h:257: Check failed: this->shape.Size() == shape.Size() (11520000000 vs. 2930065408) TBlob.get_with_shape: new and old shape do not match total elements

lionel92 changed the title from "Asnumpy() error with gluon dense under both GPU/CPU/MKLDNN environment" to "Ndarray.asnumpy() error with gluon dense under both GPU/CPU/MKLDNN environment" on May 4, 2018
lionel92 changed the title from "Ndarray.asnumpy() error with gluon dense under both GPU/CPU/MKLDNN environment" to "Ndarray.asnumpy() error with gluon dense under both GPU and CPU environment" on May 4, 2018
dwSun (Contributor) commented May 4, 2018

I modified your script like this:

from mxnet.gluon import nn
import mxnet as mx

mx.Context.default_ctx = mx.Context('cpu', 0)

layer = nn.Dense(1000)
x = mx.nd.random.uniform(shape=(16, 128, 300, 300))
x.attach_grad()
layer.collect_params().initialize()
with mx.autograd.record():
    out = layer(x)
out.backward()
print(x.grad.shape)
print(x.grad)
print(x.grad.asnumpy().shape)

with mxnet-mkl 1.1.0 from pypi, I got this:

% python3 script.py                                                                        134 ↵
(16, 128, 300, 300)
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc
[1]    17413 abort      python3 script.py

with mxnet-mkl 1.2.0b20180503 from pypi, I got this:

% python3 script.py                                                                          1 ↵
(16, 128, 300, 300)
Traceback (most recent call last):
  File "script.py", line 14, in <module>
    print(x.grad)
  File "/home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 189, in __repr__
    return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
  File "/home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 1876, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [17:24:12] src/storage/./cpu_device_storage.h:73: Failed to allocate CPU Memory

Stack trace returned 10 entries:
[bt] (0) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x17009d) [0x7f39de8c009d]
[bt] (1) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x170468) [0x7f39de8c0468]
[bt] (2) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2dc701d) [0x7f39e151701d]
[bt] (3) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2dc704d) [0x7f39e151704d]
[bt] (4) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2dcc77b) [0x7f39e151c77b]
[bt] (5) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x29140f4) [0x7f39e10640f4]
[bt] (6) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x291469f) [0x7f39e106469f]
[bt] (7) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2914ab0) [0x7f39e1064ab0]
[bt] (8) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2891843) [0x7f39e0fe1843]
[bt] (9) /home/david/.virtualenvs/mkl-dnn/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2899644) [0x7f39e0fe9644]

So, I am totally confused...

roywei (Member) commented May 9, 2018

@sandeep-krishnamurthy could you help add the labels NDArray and MKL? Thanks!

juliusshufan (Contributor) commented

@roywei Thanks for the follow-up. This issue happens on the GPU platform as well, so an MKL label might limit the scope. Could you double-check? Thanks. :)

zheng-da (Contributor) commented

Are you sure you can do this on CPU without MKLDNN?
What you are doing here is allocating a weight array of shape 11,520,000 x 1000, i.e. more than 11.5 billion elements, which is tens of gigabytes in float32.
It seems to me that the code fails while initializing the weight matrix with random numbers, before any MKLDNN operator is called.
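
For reference, a quick size check in plain Python (no MXNet needed), using the shapes from the repro script; the dense_weight_gib helper is purely illustrative, and the gradient buffer would roughly double the figure for the weight alone:

# Dense(1000) with the default flatten=True keeps the batch axis and
# collapses the rest, so the weight has shape (1000, prod(input_shape[1:])).
def dense_weight_gib(input_shape, units=1000, dtype_bytes=4):
    flat = 1
    for d in input_shape[1:]:  # skip the batch axis
        flat *= d
    return flat * units * dtype_bytes / 1024.0 ** 3

print(dense_weight_gib((16, 128, 300, 300)))  # ~42.9 GiB -> allocation fails
print(dense_weight_gib((1, 2, 30, 30)))       # ~0.007 GiB -> the "ok" input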

lionel92 (Author) commented

@zheng-da As I mentioned before, all environments (GPU/CPU/MKLDNN) fail. Based on your response, can I treat this as an out-of-memory error?

zheng-da (Contributor) commented

I think so.

piiswrong (Contributor) commented

I think this is due to the use of index_t (which is uint32_t) vs. int64_t in the tensor blob (TBlob).

This is a legacy issue. We should use int64_t for all indexing.
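
The numbers in the error message above are consistent with a 32-bit wraparound; a small arithmetic check (plain Python, just for illustration):

true_elements = 11520000 * 1000   # Dense weight elements: 11,520,000,000
wrapped = true_elements % 2**32   # what a uint32 index would hold
print(true_elements)              # 11520000000
print(wrapped)                    # 2930065408 -> matches the error message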

apeforest (Contributor) commented

Verified that PR #11742 fixes this issue. @sandeep-krishnamurthy please close this. Thanks!
