This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Race-condition and crash with SymbolBlock on GPU #18765

Closed
chinakook opened this issue Jul 21, 2020 · 7 comments · Fixed by #18768
Labels
Bug v1.x Targeting v1.x branch v2.0

Comments

@chinakook
Contributor

Description

Severe bug with nn.SymbolBlock when ctx=mx.gpu(0); on CPU it works fine.

Error Message

A malloc error, a free error, or a segmentation fault appears at random. Three separate runs:

/home/xxxxxx/anaconda3/envs/solo/lib/python3.7/site-packages/mxnet/gluon/block.py:1517: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
        data: None
  input_sym_arg_type = in_param.infer_type()[0]
[17:15:59] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]
malloc(): unsorted double linked list corrupted
[1]    87116 abort (core dumped)  python symbolblockbug.py

/home/xxxxxx/anaconda3/envs/solo/lib/python3.7/site-packages/mxnet/gluon/block.py:1517: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
        data: None
  input_sym_arg_type = in_param.infer_type()[0]
[17:21:29] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]

Segmentation fault: 11

/home/xxxxxx/anaconda3/envs/solo/lib/python3.7/site-packages/mxnet/gluon/block.py:1517: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
        data: None
  input_sym_arg_type = in_param.infer_type()[0]
[17:23:24] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]
malloc_consolidate(): invalid chunk size
[1]    87701 abort (core dumped)  python symbolblockbug.py

To Reproduce

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
import gluoncv as gcv
class NetEncoder(nn.SymbolBlock):
    def __init__(self, **kwargs):
        base_network = gcv.model_zoo.resnet50_v1(pretrained=False)
        outputs = ['stage1_activation2', 'stage2_activation3', 'stage3_activation5',
                   'stage4_activation2']

        inputs, outputs, params = gcv.nn.feature._parse_network(
            base_network, outputs, ['data'], pretrained=False, ctx=mx.cpu(), **kwargs)
        super(NetEncoder, self).__init__(outputs, inputs, params=params)

class Foo(nn.HybridBlock):
    def __init__(self):
        super(Foo, self).__init__()
        self.features = NetEncoder()

    def hybrid_forward(self, F, x):
        y = self.features(x)
        return y

a = mx.nd.random.uniform(shape=(1,3,224,224), ctx=mx.gpu(0))

f = Foo()
f.collect_params().initialize()
f.hybridize()
f.collect_params().reset_ctx(mx.gpu(0))
b = f(a)
print([x.shape for x in b])

Environment

  1. mxnet_cu102-1.7.0b20200719-py2.py3-none-manylinux2014_x86_64
  2. mxnet 2.0 master (April build)
@chinakook
Contributor Author

Solved: changing the last line to print([x.asnumpy().shape for x in b]) avoids the crash.
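A likely reason this works: MXNet runs operators asynchronously on engine worker threads. An output's `.shape` is known from shape inference as soon as the operator is pushed, so `x.shape` does not wait, whereas `asnumpy()` is a blocking copy that waits for the computation to finish. The toy model below sketches that difference with a Python worker thread; `ToyEngine`, `LazyArray`, and `slow_kernel` are hypothetical names for illustration, not MXNet's API or implementation.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToyEngine:
    """Hypothetical stand-in for an asynchronous execution engine (not MXNet's API)."""

    def __init__(self):
        # A single worker thread plays the role of a GPU worker.
        self._pool = ThreadPoolExecutor(max_workers=1)

    def push(self, fn, shape):
        # The shape is known up front (like shape inference); the values are not.
        return LazyArray(shape, self._pool.submit(fn))


class LazyArray:
    """Hypothetical array whose data is produced later by a worker thread."""

    def __init__(self, shape, future):
        self.shape = shape  # available immediately, no synchronization
        self._future = future

    def done(self):
        return self._future.done()

    def asnumpy(self):
        # Blocks until the worker has finished, mimicking a blocking copy-out.
        return self._future.result()


def slow_kernel():
    time.sleep(0.2)  # simulate kernel compilation and execution time
    return [0.0] * 1000


engine = ToyEngine()
out = engine.push(slow_kernel, shape=(1000,))
print(out.shape, out.done())            # shape is ready; the work is most likely still running
print(len(out.asnumpy()), out.done())   # blocks first, so the work is finished here
```

In this model, reading `.shape` alone never forces the pending work to complete, which mirrors why the original `print([x.shape for x in b])` gives no synchronization point while the `asnumpy()` variant does.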

@leezu
Contributor

leezu commented Jul 21, 2020

@chinakook, thanks for providing a workaround. I think it's still a bug, though.

@leezu leezu reopened this Jul 21, 2020
@leezu
Contributor

leezu commented Jul 21, 2020

A simpler reproducible example for the latest master:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

a = mx.nd.random.uniform(shape=(1,3,224,224))

backbone = gluon.model_zoo.vision.resnet18_v1()
backbone.initialize()
backbone.hybridize()

backbone(a)

# Alternative:
# backbone.reset_ctx(mx.gpu(0))
# b = backbone(a.as_in_context(mx.gpu(0)))
# print([x.shape for x in b])

sym_file, params_file = backbone.export('/tmp/model')

f = gluon.SymbolBlock.imports(sym_file, 'data', params_file)
f.reset_ctx(mx.gpu(0))
b = f(a.as_in_context(mx.gpu(0)))
print([x.shape for x in b])

It fails with:

[18:59:34] ../src/storage/storage.cc:198: Using Pooled (Naive) StorageManager for CPU
/home/ubuntu/src/mxnet-master/python/mxnet/gluon/block.py:1723: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
        data: None
  input_sym_arg_type = in_param.infer_type()[0]
[18:59:37] ../src/storage/storage.cc:198: Using Pooled (Naive) StorageManager for GPU
[(1000,)]
[18:59:38] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)

Segmentation fault: 11

zsh: abort (core dumped)  python3 symbolblockbug.py

However, the following works:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

a = mx.nd.random.uniform(shape=(1,3,224,224))

backbone = gluon.model_zoo.vision.resnet18_v1()
backbone.initialize()
backbone.hybridize()

# backbone(a)

# Alternative:
backbone.reset_ctx(mx.gpu(0))
b = backbone(a.as_in_context(mx.gpu(0)))
print([x.shape for x in b])

sym_file, params_file = backbone.export('/tmp/model')

f = gluon.SymbolBlock.imports(sym_file, 'data', params_file)
f.reset_ctx(mx.gpu(0))
b = f(a.as_in_context(mx.gpu(0)))
print([x.shape for x in b])

@leezu leezu changed the title Severe Bug with nn.SymbolBlock Race-condition and crash with SymbolBlock on GPU Jul 21, 2020
@leezu leezu added v2.0 v1.x Targeting v1.x branch and removed needs triage labels Jul 21, 2020
@leezu
Contributor

leezu commented Jul 21, 2020

Backtrace

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:51
#1  0x00007ffff70608b1 in __GI_abort () at abort.c:79
#2  0x00007ffff70a9907 in __libc_message (action=action@entry=do_abort, fmt=fmt@entry=0x7ffff71d6dfa "%s\n") at ../sysdeps/posix/libc_fatal.c:181
#3  0x00007ffff70b097a in malloc_printerr (str=str@entry=0x7ffff71d4fe8 "free(): invalid pointer") at malloc.c:5350
#4  0x00007ffff70b7e8c in _int_free (have_lock=0, p=0x7ffcc0a783d8, av=0x7ffff740bc40 <main_arena>) at malloc.c:4157
#5  __GI___libc_free (mem=0x7ffcc0a783e8) at malloc.c:3124
#6  0x00007fff461d3d46 in __gnu_cxx::new_allocator<char>::deallocate (this=0x7ffd81ffde10, __p=0x7ffcc0a783e8 "") at /usr/include/c++/7/ext/new_allocator.h:125
#7  0x00007fff461d2dc3 in std::allocator_traits<std::allocator<char> >::deallocate (__a=..., __p=0x7ffcc0a783e8 "", __n=305) at /usr/include/c++/7/bits/alloc_traits.h:462
#8  0x00007fff461d2178 in std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >::_M_destroy (this=0x7ffd81ffde10, __size=304)
    at /usr/include/c++/7/bits/basic_string.h:226
#9  0x00007fff461d187a in std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >::_M_dispose (this=0x7ffd81ffde10)
    at /usr/include/c++/7/bits/basic_string.h:221
#10 0x00007fff461d0318 in std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >::~basic_string (this=0x7ffd81ffde10, __in_chrg=<optimized out>)
    at /usr/include/c++/7/bits/basic_string.h:647
#11 0x00007fff5204c784 in mxnet::FusedOp::CompileCode (this=0x555559e922d0,
    code="using DType_output0 = float;\nstatic const int ndim_output0 = 4;\nstatic const int ndim_input_1 = 4;\nusing DType_input_1 = float;\nstatic const int ndim_input_0 = 4;\nusing DType_input_0 = float;\nstatic c"..., kernel_name="elemwise_add_Activation", dev_id=0) at ../src/operator/fusion/fused_op.cu:651
#12 0x00007fff5204df75 in mxnet::FusedOp::Forward<mshadow::gpu> (this=0x555559e922d0, attrs=..., ctx=..., inputs=std::vector of length 2, capacity 2 = {...},
    req=std::vector of length 1, capacity 1 = {...}, outputs=std::vector of length 1, capacity 1 = {...}) at ../src/operator/fusion/fused_op.cu:766
#13 0x00007fff5204eac1 in mxnet::FusedOpForwardGPU (attrs=..., ctx=..., inputs=std::vector of length 2, capacity 2 = {...}, req=std::vector of length 1, capacity 1 = {...},
    outputs=std::vector of length 1, capacity 1 = {...}) at ../src/operator/fusion/fused_op.cu:836
#14 0x00007fff466b3bfb in std::_Function_handler<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&), void (*)(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)>::_M_invoke(std::_Any_data const&, nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&) (__functor=..., __args#0=..., __args#1=...,
    __args#2=std::vector of length 2, capacity 2 = {...}, __args#3=std::vector of length 1, capacity 1 = {...}, __args#4=std::vector of length 1, capacity 1 = {...})
    at /usr/include/c++/7/bits/std_function.h:316
#15 0x00007fff464f09b4 in std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)>::operator()(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&) const (this=0x555559e6eb20, __args#0=..., __args#1=..., __args#2=std::vector of length 2, capacity 2 = {...}, __args#3=std::vector of length 1, capacity 1 = {...},
    __args#4=std::vector of length 1, capacity 1 = {...}) at /usr/include/c++/7/bits/std_function.h:706
#16 0x00007fff4656a8d6 in mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const (__closure=0x555559e6ea90, rctx=...) at ../src/imperative/./imperative_utils.h:494
#17 0x00007fff46572f70 in std::_Function_handler<void (mxnet::RunContext), mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::Op
:engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}>::_M_invoke(std::_Any_data const&, mxnet::RunContext&&) (__functor=...,
    __args#0=...) at /usr/include/c++/7/bits/std_function.h:316
#18 0x00007fff464bd868 in std::function<void (mxnet::RunContext)>::operator()(mxnet::RunContext) const (this=0x555555edb1e0, __args#0=...) at /usr/include/c++/7/bits/std_function.h:706
#19 0x00007fff464c96f6 in mxnet::engine::ThreadedEngine::BulkFlush()::{lambda(mxnet::RunContext, mxnet::engine::CallbackOnComplete)#1}::operator()(mxnet::RunContext, mxnet::engine::CallbackOnComplete) const (__closure=0x555559e374b0, ctx=..., on_complete=...)
    at ../src/engine/./threaded_engine.h:537
#20 0x00007fff464ce27b in std::_Function_handler<void (mxnet::RunContext, mxnet::engine::CallbackOnComplete), mxnet::engine::ThreadedEngine::BulkFlush()::{lambda(mxnet::RunContext, mxnet::engine::CallbackOnComplete)#1}>::_M_invoke(std::_Any_data const&, mxnet::RunContext&&, mxnet::engine::CallbackOnComplete&&) (__functor=..., __args#0=..., __args#1=...) at /usr/include/c++/7/bits/std_function.h:316
#21 0x00007fff464be704 in std::function<void (mxnet::RunContext, mxnet::engine::CallbackOnComplete)>::operator()(mxnet::RunContext, mxnet::engine::CallbackOnComplete) const (this=0x555557eedc60, __args#0=..., __args#1=...)
    at /usr/include/c++/7/bits/std_function.h:706
#22 0x00007fff464d70e4 in mxnet::engine::ThreadedEngine::ExecuteOprBlock (this=0x555556fb6980, run_ctx=..., opr_block=0x5555580bd9b0) at ../src/engine/./threaded_engine.h:381
#23 0x00007fff464dcc57 in mxnet::engine::ThreadedEnginePerDevice::GPUWorker<(dmlc::ConcurrentQueueType)0> (this=0x555556fb6980, ctx=..., is_copy_worker=false, block=0x5555570b5b40,
    ready_event=std::shared_ptr<dmlc::ManualEvent> (use count 2, weak count 0) = {...}) at ../src/engine/threaded_engine_perdevice.cc:272
#24 0x00007fff464d8baa in mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#4}::operator()() const::{lambda(std::shared_ptr<dmlc::ManualEvent>)#1}::operator()(dmlc::ManualEvent) const (__closure=0x5555580e8e20,
    ready_event=std::shared_ptr<dmlc::ManualEvent> (use count 2, weak count 0) = {...}) at ../src/engine/threaded_engine_perdevice.cc:186
#25 0x00007fff464dfc9d in std::_Function_handler<void (std::shared_ptr<dmlc::ManualEvent>), mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#4}::operator()() const::{lambda(std::shared_ptr<dmlc::ManualEvent>)#1}>::_M_invoke(std::_Any_data const&, std::shared_ptr<dmlc::ManualEvent>&&) (__functor=..., __args#0=...) at /usr/include/c++/7/bits/std_function.h:316
#26 0x00007fff464e088f in std::function<void (std::shared_ptr<dmlc::ManualEvent>)>::operator()(std::shared_ptr<dmlc::ManualEvent>) const (this=0x555558040268, __args#0=std::shared_ptr<dmlc::ManualEvent> (empty) = {...})
    at /usr/include/c++/7/bits/std_function.h:706
#27 0x00007fff464ddfbf in std::__invoke_impl<void, std::function<void (std::shared_ptr<dmlc::ManualEvent>)>, std::shared_ptr<dmlc::ManualEvent> >(std::__invoke_other, std::function<void (std::shared_ptr<dmlc::ManualEvent>)>&&, std::shared_ptr<dmlc::ManualEvent>&&) (__f=...) at /usr/include/c++/7/bits/invoke.h:60
#28 0x00007fff464d9b7d in std::__invoke<std::function<void (std::shared_ptr<dmlc::ManualEvent>)>, std::shared_ptr<dmlc::ManualEvent> >(std::function<void (std::shared_ptr<dmlc::ManualEvent>)>&&, std::shared_ptr<dmlc::ManualEvent>&&) (__fn=...)
    at /usr/include/c++/7/bits/invoke.h:95
#29 0x00007fff464e746b in std::thread::_Invoker<std::tuple<std::function<void (std::shared_ptr<dmlc::ManualEvent>)>, std::shared_ptr<dmlc::ManualEvent> > >::_M_invoke<0ul, 1ul>(std::_Index_tuple<0ul, 1ul>) (this=0x555558040258) at /usr/include/c++/7/thread:234
#30 0x00007fff464e73d3 in std::thread::_Invoker<std::tuple<std::function<void (std::shared_ptr<dmlc::ManualEvent>)>, std::shared_ptr<dmlc::ManualEvent> > >::operator()() (this=0x555558040258) at /usr/include/c++/7/thread:243
#31 0x00007fff464e7372 in std::thread::_State_impl<std::thread::_Invoker<std::tuple<std::function<void (std::shared_ptr<dmlc::ManualEvent>)>, std::shared_ptr<dmlc::ManualEvent> > > >::_M_run() (this=0x555558040250) at /usr/include/c++/7/thread:186
#32 0x00007ffff070a6df in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#33 0x00007ffff7bbd6db in start_thread (arg=0x7ffd81fff700) at pthread_create.c:463
#34 0x00007ffff7141a3f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

@ptrendx, do you think this is a bug in the fused op or in the engine?

@eric-haibin-lin
Member

Does adding mx.nd.waitall() help?

@chinakook
Contributor Author

@eric-haibin-lin Yes, it does.
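Why `mx.nd.waitall()` helps: it blocks until every operation pushed to MXNet's asynchronous engine has completed. In leezu's failing log, `[(1000,)]` is printed before the cuDNN autotuning message, so the script reaches its end while GPU work is still queued, and the backtrace shows the abort on a GPU worker thread inside `FusedOp::CompileCode`. The sketch below is a toy model of such a global barrier, assuming a hypothetical `ToyEngine` built on `concurrent.futures`; it is not MXNet's implementation, only an illustration of why waiting for all pending work before the script ends presumably closes that window.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait


class ToyEngine:
    """Hypothetical asynchronous engine used only for illustration (not MXNet's API)."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=2)
        self._pending = []

    def push(self, fn):
        # Returns immediately; the work runs later on a worker thread.
        fut = self._pool.submit(fn)
        self._pending.append(fut)
        return fut

    def waitall(self):
        # Global barrier: block until every pushed operation has finished,
        # similar in spirit to mx.nd.waitall().
        wait(self._pending)
        self._pending.clear()


log = []


def op(name, seconds):
    # Build a fake operator that takes `seconds` to run and records its name.
    def run():
        time.sleep(seconds)
        log.append(name)
    return run


engine = ToyEngine()
engine.push(op("compile_fused_kernel", 0.2))
engine.push(op("convolution", 0.1))
engine.waitall()  # without this barrier, `log` could still be incomplete when printed
print(sorted(log))  # ['compile_fused_kernel', 'convolution']
```

In the reproducer, the equivalent would presumably be a call to `mx.nd.waitall()` after the GPU forward pass (`b = f(...)`) and before the script exits.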

@chinakook
Contributor Author

I've tested it, and #18768 solves this problem.
