This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve message when quantized_dtype uint8 is used with gpu context #14092

Closed
anirudh2290 opened this issue Feb 8, 2019 · 3 comments · Fixed by #14094
Labels
Quantization Issues/Feature Requests related to Quantization

Comments

@anirudh2290
Member

Description

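As far as I can tell, quantized_dtype='uint8' is not supported with a GPU context. Instead of reporting that, the forward pass fails on a type-flag check inside the transpose operator (see the error message below). A clearer error message would be useful.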
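A minimal sketch of the kind of early check being asked for. The helper name `check_quantized_dtype` is hypothetical and not part of MXNet's API; it only assumes that the caller passes the context's device type as a string (MXNet contexts expose this as `ctx.device_type`, e.g. 'cpu' or 'gpu'). It is not necessarily how the eventual fix was implemented.

```python
def check_quantized_dtype(device_type, quantized_dtype):
    """Hypothetical pre-flight check, meant to run before quantization starts.

    device_type: 'cpu' or 'gpu' (for an MXNet context, ctx.device_type).
    quantized_dtype: the dtype string requested by the user.
    """
    if quantized_dtype not in ('int8', 'uint8'):
        raise ValueError(
            "unknown quantized_dtype %r, expected 'int8' or 'uint8'"
            % (quantized_dtype,))
    # Assumption taken from this issue: uint8 quantization has no GPU support.
    if device_type == 'gpu' and quantized_dtype == 'uint8':
        raise ValueError(
            "quantized_dtype='uint8' is not supported with a GPU context; "
            "use quantized_dtype='int8' or a CPU context instead")
    return quantized_dtype
```

Called as `check_quantized_dtype(ctx.device_type, quantized_dtype)` before quantizing, this would fail immediately with a readable message rather than at the first forward pass.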

Environment info (Required)

----------Python Info----------
('Version      :', '2.7.12')
('Compiler     :', 'GCC 5.4.0 20160609')
('Build        :', ('default', 'Nov 12 2018 14:36:49'))
('Arch         :', ('64bit', 'ELF'))
------------Pip Info-----------
('Version      :', '18.0')
('Directory    :', '/usr/local/lib/python2.7/dist-packages/pip')
----------MXNet Info-----------
/usr/local/lib/python2.7/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Assertion failure at kmp_runtime.cpp(6481): __kmp_team_pool == __null.
OMP: Error #13: Assertion failure at kmp_runtime.cpp(6481).
OMP: Hint: Please submit a bug report with this message, compile and run commands used, and machine configuration info including native compiler and operating system versions. Faster response will be obtained by including all program sources. For information on submitting this issue, please see https://bugs.llvm.org/.
Assertion failure at kmp_runtime.cpp(6481): __kmp_team_pool == __null.
OMP: Error #13: Assertion failure at kmp_runtime.cpp(6481).
OMP: Hint: Please submit a bug report with this message, compile and run commands used, and machine configuration info including native compiler and operating system versions. Faster response will be obtained by including all program sources. For information on submitting this issue, please see https://bugs.llvm.org/.
Assertion failure at kmp_runtime.cpp(6481): __kmp_team_pool == __null.
OMP: Error #13: Assertion failure at kmp_runtime.cpp(6481).
OMP: Hint: Please submit a bug report with this message, compile and run commands used, and machine configuration info including native compiler and operating system versions. Faster response will be obtained by including all program sources. For information on submitting this issue, please see https://bugs.llvm.org/.
Assertion failure at kmp_runtime.cpp(6481): __kmp_team_pool == __null.
OMP: Error #13: Assertion failure at kmp_runtime.cpp(6481).
OMP: Hint: Please submit a bug report with this message, compile and run commands used, and machine configuration info including native compiler and operating system versions. Faster response will be obtained by including all program sources. For information on submitting this issue, please see https://bugs.llvm.org/.
('Version      :', '1.5.0')
('Directory    :', '/home/ubuntu/experimentals/1.4_release/python/mxnet')
Hashtag not found. Not installed from pre-built package.
----------System Info----------
('Platform     :', 'Linux-4.4.0-1075-aws-x86_64-with-Ubuntu-16.04-xenial')
('system       :', 'Linux')
('node         :', 'ip-172-31-71-199')
('release      :', '4.4.0-1075-aws')
('version      :', '#85-Ubuntu SMP Thu Jan 17 17:15:12 UTC 2019')
----------Hardware Info----------
('machine      :', 'x86_64')
('processor    :', 'x86_64')
Assertion failure at kmp_runtime.cpp(6481): __kmp_team_pool == __null.
OMP: Error #13: Assertion failure at kmp_runtime.cpp(6481).
OMP: Hint: Please submit a bug report with this message, compile and run commands used, and machine configuration info including native compiler and operating system versions. Faster response will be obtained by including all program sources. For information on submitting this issue, please see https://bugs.llvm.org/.
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0160 sec, LOAD: 0.3460 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0035 sec, LOAD: 0.0751 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0150 sec, LOAD: 0.1376 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0033 sec, LOAD: 0.0343 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0563 sec, LOAD: 0.4060 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.2513 sec, LOAD: 0.3613 sec.

Package used (Python/R/Scala/Julia):
Python

Build info (Required if built from source)

Compiler (gcc/clang/mingw/visual studio): gcc

MXNet commit hash:
a85b3f0

Build config:

cd build && cmake VERBOSE=1 -DUSE_CUDA=ON -DUSE_CUDNN=ON -DUSE_OPENMP=ON -DCMAKE_BUILD_TYPE=Debug -DUSE_DIST_KVSTORE=0 -DUSE_OPENCV=1 -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda-10.0 -DCUDNN_ROOT=/usr/local/cuda-10.0 -GNinja .. && ninja -v

Error Message:

/home/ubuntu/experimentals/1.4_release/python/mxnet/module/base_module.py:55: UserWarning: You created Module with Module(..., label_names=['softmax_label']) but input with name 'softmax_label' is not found in symbol.list_arguments(). Did you mean one of:
	data
	conv_cudnn_weight_quantize
	conv_cudnn_weight_quantize_min
	conv_cudnn_weight_quantize_max
  warnings.warn(msg)
Traceback (most recent call last):
  File "simple_quantization.py", line 56, in <module>
    final = mod.get_outputs()[0].asnumpy()
  File "/home/ubuntu/experimentals/1.4_release/python/mxnet/ndarray/ndarray.py", line 1995, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/home/ubuntu/experimentals/1.4_release/python/mxnet/base.py", line 252, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [02:07:55] /home/ubuntu/experimentals/1.4_release/src/operator/quantization/../tensor/matrix_op-inl.h:250: Check failed: src.type_flag_ == ret.type_flag_ (3 vs. 5)

Stack trace returned 10 entries:
[bt] (0) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(dmlc::StackTrace[abi:cxx11]()+0x54) [0x7f2ba3664b69]
[bt] (1) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x2a) [0x7f2ba3664e50]
[bt] (2) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(void mxnet::op::TransposeImpl<mshadow::gpu>(mxnet::RunContext, mxnet::TBlob const&, mxnet::TBlob const&, nnvm::TShape const&)+0xd1) [0x7f2ba37c69cc]
[bt] (3) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(mxnet::op::QuantizedCuDNNConvOp<signed char, float, int>::Forward(mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x8c9) [0x7f2ba37c24e3]
[bt] (4) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(mxnet::op::QuantizedConvForwardGPU(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x3c2) [0x7f2ba37aa345]
[bt] (5) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(std::_Function_handler<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&), void (*)(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)>::_M_invoke(std::_Any_data const&, nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x91) [0x7f2ba3792cbb]
[bt] (6) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)>::operator()(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&) const+0xa6) [0x7f2ba6e76178]
[bt] (7) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(mxnet::exec::FComputeExecutor::Run(mxnet::RunContext, bool)+0xa2) [0x7f2ba939eccc]
[bt] (8) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(+0x83c603f) [0x7f2ba93ba03f]
[bt] (9) /home/ubuntu/experimentals/1.4_release/python/mxnet/../../build/libmxnet.so(+0x83c9ed9) [0x7f2ba93bded9]
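
The "(3 vs. 5)" in the failed check is hard to read without knowing the dtype numbering. The sketch below decodes it, assuming the usual mshadow type-flag numbering (0 float32, 1 float64, 2 float16, 3 uint8, 4 int32, 5 int8, 6 int64). The table and the function name `describe_type_mismatch` are illustrative and written out by hand here, not imported from MXNet.

```python
# Assumed mshadow type-flag numbering, written out by hand for illustration.
TYPE_FLAG_NAMES = {
    0: 'float32',
    1: 'float64',
    2: 'float16',
    3: 'uint8',
    4: 'int32',
    5: 'int8',
    6: 'int64',
}


def describe_type_mismatch(src_flag, ret_flag):
    """Turn a 'src.type_flag_ == ret.type_flag_ (A vs. B)' failure into dtype names."""
    src = TYPE_FLAG_NAMES.get(src_flag, 'unknown(%d)' % src_flag)
    ret = TYPE_FLAG_NAMES.get(ret_flag, 'unknown(%d)' % ret_flag)
    return 'source dtype %s vs. output dtype %s' % (src, ret)


print(describe_type_mismatch(3, 5))
```

Under that numbering, the check compares a uint8 source tensor with an int8 output tensor, which fits the `QuantizedCuDNNConvOp<signed char, float, int>` frame in the stack trace.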

Minimum reproducible example

import mxnet as mx
from mxnet.io import NDArrayIter
from mxnet.test_utils import DummyIter, assert_almost_equal
import logging
from collections import namedtuple

data_shape = (32, 64, 56, 56)
data = mx.sym.Variable(name="data", shape=data_shape, dtype='float32')
conv_cudnn = mx.sym.Convolution(data=data, kernel=(1, 1), num_filter=256, pad=(0, 0), stride=(1, 1),
                                no_bias=True, layout='NCHW', cudnn_off=False, name="conv_cudnn")
input_data = mx.nd.random.normal(0, 0.2, shape=data_shape, ctx=mx.gpu(0))
conv_weight_name = conv_cudnn.list_arguments()[1]
arg_shapes, _, _ = conv_cudnn.infer_shape(data=data_shape)
mod = mx.mod.Module(conv_cudnn)
mod.bind(for_training=False, data_shapes=[('data', (32, 64, 56, 56))], label_shapes=mod._label_shapes)
mod.init_params()

calib_mode = 'entropy'
num_calib_batches = 1
batch_size = 2
calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
                                                         or name.find('sc') != -1
                                                         or name.find('fc') != -1)
logging.basicConfig()
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)
calib_data = NDArrayIter(data=input_data)
calib_data = DummyIter(calib_data)
quantized_dtype = 'uint8'

if calib_mode == 'none':
    cqsym, qarg_params, aux_params = mx.contrib.quant.quantize_model(sym=mod._symbol, arg_params=mod._arg_params, aux_params=mod._aux_params,
                                                                     ctx=mx.gpu(0), excluded_sym_names=None,
                                                                     calib_mode=calib_mode, quantized_dtype=quantized_dtype,
                                                                     logger=logger)
else:
    cqsym, qarg_params, aux_params = mx.contrib.quant.quantize_model(sym=mod._symbol, arg_params=mod._arg_params, aux_params=mod._aux_params,
                                                                     ctx=mx.gpu(0), excluded_sym_names=None,
                                                                     calib_mode=calib_mode, calib_data=calib_data,
                                                                     num_calib_examples=num_calib_batches * batch_size,
                                                                     calib_layer=calib_layer, quantized_dtype=quantized_dtype,
                                                                     logger=logger)

mod = mx.mod.Module(cqsym, context=mx.gpu(0))
mod.bind(for_training=False, data_shapes=[('data', (32, 64, 56, 56))], label_shapes=mod._label_shapes)
mod.set_params(qarg_params, aux_params)
Batch = namedtuple('Batch', ['data'])
mod.forward(Batch([input_data]), is_train=False)
final = mod.get_outputs()[0].asnumpy()
@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Bug

@anirudh2290 anirudh2290 added the Quantization Issues/Feature Requests related to Quantization label Feb 8, 2019
@pengzhao-intel
Contributor

@rajeshii would you mind improving the error message?

@jitMatrix
Contributor

Added :)
