
Out of memory error in 3d Conv for matrix splits > 10, CUDNN strange behaviour #14029

Closed
Vikas-kum opened this issue Jan 30, 2019 · 9 comments

@Vikas-kum
Contributor

Vikas-kum commented Jan 30, 2019

Description

Memory bloat (OOM) in 3D Conv when the matrix slice size is greater than 10. If I run this code on an EC2 p2.xlarge with 12 GB of GPU memory, the program runs fine for
a = net(x[:, :, :, :10, :10])
print(a.shape)

a = net(x[:, :, :, :9, :9])
print(a.shape)

but starts hitting a CUDA OOM error for:

a = net(x[:, :, :, :11, :11])
print(a.shape)

import os
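# NaiveEngine executes operators synchronously, so kernels appear in launch
# order in the profile; autotune=0 skips cudnnFind() and accepts whatever algo
# cudnnGet() suggests.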
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'

import mxnet as mx

from mxnet.gluon import nn

# from resnet_i3d import BasicBlockV1
def _conv3x3(channels, stride, in_channels):
    return nn.Conv3D(channels, kernel_size=3, strides=stride, padding=1, use_bias=False, in_channels=in_channels)
# Blocks
class BasicBlockV1(mx.gluon.HybridBlock):
    r"""BasicBlock V1 from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    This is used for ResNet V1 for 18, 34 layers.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(BasicBlockV1, self).__init__(**kwargs)
        with self.name_scope():
            self.body = nn.HybridSequential(prefix='')
            self.body.add(_conv3x3(channels, stride, in_channels))
            self.body.add(nn.BatchNorm())
            self.body.add(nn.Activation('relu'))
            self.body.add(_conv3x3(channels, 1, channels))
            self.body.add(nn.BatchNorm())
            if downsample:
                self.downsample = nn.HybridSequential(prefix='')
                self.downsample.add(nn.Conv3D(channels, kernel_size=1, strides=stride, use_bias=False, in_channels=in_channels))
                self.downsample.add(nn.BatchNorm())
            else:
                self.downsample = None

    def hybrid_forward(self, F, x):
        residual = x
        x = self.body(x)
        if self.downsample:
            residual = self.downsample(residual)
        x = F.Activation(residual+x, act_type='relu')
        return x


ctx = mx.gpu(0)

net = nn.HybridSequential(prefix='')
channels = [256, 512]
net.add(BasicBlockV1(channels[-1], 1, downsample=True, in_channels=channels[-2], prefix=''))
net.add(BasicBlockV1(channels[-1], 1, False, in_channels=channels[-1], prefix=''))
net.initialize(ctx=ctx)

x = mx.nd.random.normal(0, 1, (300, 256, 2, 14, 14), ctx=ctx)

a = net(x[:, :, :, :10, :10])
print(a.shape)

a = net(x[:, :, :, :9, :9])
print(a.shape)

b = net(x[:, :, :, :11, :11])
print(b.shape)

Environment info (Required)

mxnet-cu92==1.3.1, gluoncv==0.3.0
CUDA 9.2, cuDNN 7.1
Instance used: p2.xlarge on EC2, 12 GB of GPU memory

What to do:
Copy the script above to my_test.py.
Run it on a p2.xlarge with: python my_test.py
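(The kernel traces in the comments below appear to have been collected with nvprof; the profiler banner shows command: python bb.py. Running something like nvprof python my_test.py should reproduce them.)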

@Vikas-kum
Contributor Author

Thanks, I was able to repro the issue.

With a slice of 11, it looks like execution takes a different CUDA code path.
With 11, this is where the branching occurs:

_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_13slice_forwardILi5ELi1EN7mshadow3gpuEEEJPfS7_NS4_5ShapeILi5EEES9_NS_6common11StaticArrayIiLi5EEESC_EEEviDpT0_ [731]
3.32607s  42.809ms         (131072 1 1)       (16 16 1)        56  10.160KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void fft3d_r2c_16x16x16<float, float, float2>(float2*, float*, int3, int3, int3, int3, int3, bool) [766]
3.36884s  25.372ms          (76800 1 1)       (16 16 1)        56  10.160KB        0B         -           -           -           -    Tesla K80 (0)         1        18  void fft3d_r2c_16x16x16<float, float, float2>(float2*, float*, int3, int3, int3, int3, int3, bool) [786]
3.39416s  53.754ms          (2048 72 1)       (256 1 1)        46  8.1250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *) [774] ...

See the details below for the GPU kernel traces for :10, :10 vs. :11, :11.

Details of GPU allocation with a = net(x[:, :, :, :10, :10])

[01:26:50] src/engine/engine.cc:55: MXNet start using engine: NaiveEngine
==21970== NVPROF is profiling process 21970, command: python bb.py
(300, 512, 2, 10, 10)
==21970== Profiling application: python bb.py
[01:26:56] src/engine/naive_engine.cc:55: Engine shutdown
==21970== Profiling result:
   Start  Duration            Grid Size      Block Size     Regs*    SSMem*    DSMem*      Size  Throughput  SrcMemType  DstMemType           Device   Context    Stream  Name
2.69533s  2.3040us                    -               -         -         -         -      112B  46.359MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.02670s  2.0480us                    -               -         -         -         -      112B  52.154MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.02685s  1.5680us                    -               -         -         -         -      112B  68.120MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.02702s  1.5680us                    -               -         -         -         -      112B  68.120MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.02718s  1.5680us                    -               -         -         -         -      112B  68.120MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.02776s  4.8000us                    -               -         -         -         -  2.3750KB  483.19MB/s      Device           -    Tesla K80 (0)         1        26  [CUDA memset]
3.02778s  3.7440us                    -               -         -         -         -  2.3750KB  619.48MB/s      Device           -    Tesla K80 (0)         1        27  [CUDA memset]
3.02780s  3.7120us                    -               -         -         -         -  2.3750KB  624.82MB/s      Device           -    Tesla K80 (0)         1        28  [CUDA memset]
3.02782s  3.7120us                    -               -         -         -         -  2.3750KB  624.82MB/s      Device           -    Tesla K80 (0)         1        29  [CUDA memset]
3.02861s  2.1120us                    -               -         -         -         -      112B  50.574MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.02959s  1.8464ms                    -               -         -         -         -  13.500MB  7.1400GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.03217s  131.20us          (13824 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [594]
3.07020s  3.6699ms                    -               -         -         -         -  27.000MB  7.1848GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.07460s  257.12us          (27648 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [610]
3.07704s  71.007us                    -               -         -         -         -  512.00KB  6.8765GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.07748s  6.9440us            (512 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [626]
3.12820s  3.7064ms                    -               -         -         -         -  27.000MB  7.1141GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.13270s  257.18us          (27648 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [642]
3.16837s  3.6899ms                    -               -         -         -         -  27.000MB  7.1458GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.17289s  257.18us          (27648 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [658]
3.17373s  81.503us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [667]
3.17402s  85.599us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [676]
3.17430s  84.639us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [685]
3.17458s  84.127us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [694]
3.17542s  2.4320us                    -               -         -         -         -        4B  1.5685MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.17543s  1.5680us                    -               -         -         -         -        4B  2.4328MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.17547s  50.267ms            (128 1 1)       (256 1 1)        32        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_18SampleNormalKernelIN7mshadow3gpuEEEJNS_6common6random13RandGeneratorIS5_fEEiijjPfSB_SB_EEEviDpT0_ [716]
3.22654s  3.6994ms          (60000 1 1)       (256 1 1)        25        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_13slice_forwardILi5ELi1EN7mshadow3gpuEEEJPfS7_NS4_5ShapeILi5EEES9_NS_6common11StaticArrayIiLi5EEESC_EEEviDpT0_ [731]
3.23131s  329.61ms            (938 2 1)        (8 32 1)       108  10.250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *) [758]
3.56288s  2.2720us                    -               -         -         -         -  2.0000KB  859.65MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.56324s  2.1430us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [775]
3.56381s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.56415s  1.9840us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [791]
3.56472s  2.0160us                    -               -         -         -         -  2.0000KB  968.81MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.56525s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.56639s  6.4328ms            (512 1 1)       (256 1 1)        36        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>(int=1, mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>, float, int=1, int=1, int=1, int=1, int=1, float, unsigned int) [827]
3.57302s  1.8631ms           (7500 1 1)       (128 1 1)        20        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*) [848]
3.57518s  674.13ms            (938 2 1)        (8 32 1)       108  10.250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *) [875]
4.25052s  2.3040us                    -               -         -         -         -  2.0000KB  847.71MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.25068s  3.2000us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [890]
4.25129s  2.0800us                    -               -         -         -         -  2.0000KB  939.00MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.25162s  1.9840us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [906]
4.25218s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.25273s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.25329s  6.4278ms            (512 1 1)       (256 1 1)        36        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>(int=1, mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>, float, int=1, int=1, int=1, int=1, int=1, float, unsigned int) [940]
4.26006s  13.323ms            (938 2 1)        (8 32 1)       108  10.250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *) [972]
4.27427s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.27440s  2.4320us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [985]
4.27495s  1.9520us                    -               -         -         -         -  2.0000KB  0.9771GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.27528s  1.9840us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [1001]
4.27584s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.27636s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.27751s  6.4224ms            (512 1 1)       (256 1 1)        36        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>(int=1, mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>, float, int=1, int=1, int=1, int=1, int=1, float, unsigned int) [1037]
4.28411s  2.9166ms          (65535 1 1)       (256 1 1)        10        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS0_10mshadow_op4plusELi1EEEJPfS7_S7_EEEviDpT0_ [1054]
4.28768s  1.8733ms           (7500 1 1)       (128 1 1)        20        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*) [1070]
4.28985s  672.74ms            (938 2 1)        (8 32 1)       108  10.250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *) [1098]
4.96379s  2.4000us                    -               -         -         -         -  2.0000KB  813.80MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.96393s  3.2960us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [1111]
4.96454s  2.0480us                    -               -         -         -         -  2.0000KB  953.67MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.96489s  1.9840us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [1127]
4.96545s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.96597s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.96653s  6.4416ms            (512 1 1)       (256 1 1)        36        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>(int=1, mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>, float, int=1, int=1, int=1, int=1, int=1, float, unsigned int) [1161]
4.97314s  1.8597ms           (7500 1 1)       (128 1 1)        20        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*) [1182]
4.97525s  673.72ms            (938 2 1)        (8 32 1)       108  10.250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, float const *, int, cudnn::detail::implicit_convolveNd_sgemm<float, int=3, int=512, int=6, int=8, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, float const *, kernel_convNd_params, int, float, float, int, float const *, float const *) [1206]
5.65040s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
5.65055s  2.5600us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [1221]
5.65116s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
5.65149s  1.9520us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [1237]
5.65205s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
5.65257s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
5.65322s  6.4391ms            (512 1 1)       (256 1 1)        36        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>(int=1, mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>, float, int=1, int=1, int=1, int=1, int=1, float, unsigned int) [1271]
5.65984s  2.9161ms          (65535 1 1)       (256 1 1)        10        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS0_10mshadow_op4plusELi1EEEJPfS7_S7_EEEviDpT0_ [1288]
5.66288s  1.8583ms           (7500 1 1)       (128 1 1)        20        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*) [1302]



With a = net(x[:, :, :, :11, :11])

==22076== Profiling application: python bb.py
[01:27:32] src/engine/naive_engine.cc:55: Engine shutdown
==22076== Profiling result:
   Start  Duration            Grid Size      Block Size     Regs*    SSMem*    DSMem*      Size  Throughput  SrcMemType  DstMemType           Device   Context    Stream  Name
2.70511s  2.3040us                    -               -         -         -         -      112B  46.359MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.03518s  2.0160us                    -               -         -         -         -      112B  52.982MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.03530s  1.5680us                    -               -         -         -         -      112B  68.120MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.03547s  1.5680us                    -               -         -         -         -      112B  68.120MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.03563s  1.5360us                    -               -         -         -         -      112B  69.539MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.03622s  4.9280us                    -               -         -         -         -  2.3750KB  470.64MB/s      Device           -    Tesla K80 (0)         1        26  [CUDA memset]
3.03625s  3.9360us                    -               -         -         -         -  2.3750KB  589.26MB/s      Device           -    Tesla K80 (0)         1        27  [CUDA memset]
3.03626s  3.7440us                    -               -         -         -         -  2.3750KB  619.48MB/s      Device           -    Tesla K80 (0)         1        28  [CUDA memset]
3.03628s  3.7120us                    -               -         -         -         -  2.3750KB  624.82MB/s      Device           -    Tesla K80 (0)         1        29  [CUDA memset]
3.03714s  2.3360us                    -               -         -         -         -      112B  45.724MB/s    Pageable      Device    Tesla K80 (0)         1         7  [CUDA memcpy HtoD]
3.03808s  1.8467ms                    -               -         -         -         -  13.500MB  7.1389GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.04068s  130.24us          (13824 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [594]
3.10729s  3.6815ms                    -               -         -         -         -  27.000MB  7.1621GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.11170s  256.96us          (27648 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [610]
3.11457s  71.168us                    -               -         -         -         -  512.00KB  6.8610GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.11509s  6.9760us            (512 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [626]
3.18066s  3.7273ms                    -               -         -         -         -  27.000MB  7.0742GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.18521s  256.99us          (27648 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [642]
3.24959s  3.6985ms                    -               -         -         -         -  27.000MB  7.1292GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.25406s  257.05us          (27648 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [658]
3.25491s  79.680us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [667]
3.25519s  82.495us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [676]
3.25547s  81.023us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [685]
3.25575s  82.688us            (128 1 1)       (256 1 1)        24        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  mxnet::common::random::rand_generator_seed_kernel(curandStatePhilox4_32_10*, int, unsigned int) [694]
3.25662s  2.4640us                    -               -         -         -         -        4B  1.5482MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.25663s  1.5680us                    -               -         -         -         -        4B  2.4328MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
3.25667s  50.391ms            (128 1 1)       (256 1 1)        32        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_18SampleNormalKernelIN7mshadow3gpuEEEJNS_6common6random13RandGeneratorIS5_fEEiijjPfSB_SB_EEEviDpT0_ [716]
3.30790s  4.4087ms          (65535 1 1)       (256 1 1)        25        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_13slice_forwardILi5ELi1EN7mshadow3gpuEEEJPfS7_NS4_5ShapeILi5EEES9_NS_6common11StaticArrayIiLi5EEESC_EEEviDpT0_ [731]
3.32607s  42.809ms         (131072 1 1)       (16 16 1)        56  10.160KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void fft3d_r2c_16x16x16<float, float, float2>(float2*, float*, int3, int3, int3, int3, int3, bool) [766]
3.36884s  25.372ms          (76800 1 1)       (16 16 1)        56  10.160KB        0B         -           -           -           -    Tesla K80 (0)         1        18  void fft3d_r2c_16x16x16<float, float, float2>(float2*, float*, int3, int3, int3, int3, int3, bool) [786]
3.39416s  53.754ms          (2048 72 1)       (256 1 1)        46  8.1250KB        0B         -           -           -           -    Tesla K80 (0)         1        14  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *) [774]
3.44786s  31.851ms          (1200 72 1)       (256 1 1)        46  8.1250KB        0B         -           -           -           -    Tesla K80 (0)         1        18  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *) [794]
3.47972s  431.09ms           (8 5 2304)       (16 16 1)       127  8.1445KB        0B         -           -           -           -    Tesla K80 (0)         1        18  cgemm_strided_batched_sm35_ldg_tn_64x8x64x16x16 [816]
3.91082s  122.35ms          (36 4800 1)       (256 1 1)        46  8.1250KB        0B         -           -           -           -    Tesla K80 (0)         1        18  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *) [824]
4.03318s  46.092ms         (153600 1 1)       (16 16 1)        59  10.160KB        0B         -           -           -           -    Tesla K80 (0)         1        18  void fft3d_c2r_16x16x16<float2, float, float>(float*, float2*, int3, int3, int3, int3, int3, float, float, bool, int, float*, float*) [840]
4.09256s  2.6560us                    -               -         -         -         -  2.0000KB  735.36MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.09280s  2.5920us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [873]
4.09326s  1.9840us                    -               -         -         -         -  2.0000KB  984.44MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.09347s  2.0160us              (2 1 1)       (256 1 1)         8        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_ [889]
4.09391s  1.9830us                    -               -         -         -         -  2.0000KB  984.93MB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.09436s  1.9520us                    -               -         -         -         -  2.0000KB  0.9771GB/s    Pageable      Device    Tesla K80 (0)         1        14  [CUDA memcpy HtoD]
4.09517s  7.2510ms            (512 1 1)       (256 1 1)        36        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>(int=1, mxnet::op::batchnorm::cuda::BatchNormalizationUpdateOutputInferenceKernel<float, float, mxnet::op::batchnorm::cuda::DeviceTensor<float, int=1>, mxnet::op::batchnorm::BNTensor3<float>>, float, int=1, int=1, int=1, int=1, int=1, float, unsigned int) [925]
4.10261s  2.3114ms           (9600 1 1)       (128 1 1)        20        0B        0B         -           -           -           -    Tesla K80 (0)         1        14  void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*) [946]

Regs: Number of registers used per CUDA thread. This number includes registers used internally by the CUDA driver and/or tools and can be more than what the compiler shows.
SSMem: Static shared memory allocated per CUDA block.
DSMem: Dynamic shared memory allocated per CUDA block.
SrcMemType: The type of source memory accessed by memory operation/copy
DstMemType: The type of destination memory accessed by memory operation/copy
======== Error: Application returned non-zero code 1

@Vikas-kum
Contributor Author

This looks like a cuDNN implementation bug;
different code inside cuDNN appears to be invoked for slice sizes > 10.

For the size-11 case, void fft3d_r2c_16x16x16 is called; for size 10 or less, void cudnn::detail::implicit_convolveNd_sgemm is called instead.
We verified that disabling cuDNN for the convolution fixes the problem.
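For reference, here is a minimal sketch of what "disabling cuDNN for the convolution" can look like at the operator level, using the cudnn_off flag of mx.nd.Convolution (the shapes are only illustrative; this is not necessarily the exact check we ran):

import mxnet as mx

ctx = mx.gpu(0)
x = mx.nd.random.normal(shape=(300, 256, 2, 11, 11), ctx=ctx)
w = mx.nd.random.normal(shape=(512, 256, 3, 3, 3), ctx=ctx)
# cudnn_off=True forces MXNet's own convolution kernel instead of a cuDNN algo,
# so no cuDNN workspace (FFT or otherwise) is requested.
y = mx.nd.Convolution(data=x, weight=w, no_bias=True, num_filter=512,
                      kernel=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1),
                      cudnn_off=True)
print(y.shape)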

@ptrendx @DickJC123 would you be able to help here if you have any idea about this cuDNN behavior, or point this to the right people?

@DickJC123
Contributor

We will investigate and loop in the cuDNN team.

@vdantu
Contributor

vdantu commented Jan 30, 2019

@mxnet-label-bot add [Bug, Cuda, memory]

@vdantu
Contributor

vdantu commented Jan 30, 2019

@mxnet-label-bot remove [Cuda]

@marcoabreu marcoabreu removed the CUDA label Jan 30, 2019
@DickJC123
Contributor

DickJC123 commented Jan 31, 2019

I was able to repro this OOM on a 12 GB Pascal. At the point of failure, it was asking for a 10 GB temporary workspace! Since you've set MXNET_CUDNN_AUTOTUNE_DEFAULT=0, doesn't that say you're willing to accept what cudnnGet() returns for the algo, regardless of workspace needs?

@Vikas-kum
Contributor Author

" Since you've set MXNET_CUDNN_AUTOTUNE_DEFAULT=0, doesn't that say you're willing to accept what cudnnGet() returns for the algo, regardless of workspace needs?"

Yes. But to me, it's still strange that amount of memory required when going from 10 to 11 increases by such huge factor and I wanted to check with cudnn team if this is expected and known behavior of cudnn or there is some bug causing this memory bloat.

@DickJC123
Contributor

DickJC123 commented Jan 31, 2019

As pointed out earlier, going from 10 to 11 is the threshold at which cuDNN thinks the FFT implementation is fastest. That algo apparently has a huge workspace requirement, probably related to the convolution being 3D. There is no cuDNN bug here. You have a couple of remedies:

  1. Set MXNET_CUDNN_AUTOTUNE_DEFAULT=1. That will result in all convolution algos in your model being chosen by cudnnFind(), subject to the limitation that the workspace is less than 1 GB. The detrimental FFT algo will be avoided because its workspace is too large (although the model may run slower).
  2. Leave MXNET_CUDNN_AUTOTUNE_DEFAULT=0, but control the 3D convolution locally, e.g. with
    Convolution(..., cudnn_tune='limited_workspace', ...). Only the problem Convolution will have its algo determined by cudnnFind(), subject to a workspace limit of 1 GB. If you don't like the 1 GB limit, override it locally, e.g. with workspace=2048 to set the workspace to 2 GB (see the sketch after this list).
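A minimal sketch of remedy 2 at the operator level (illustrative shapes; this assumes the cudnn_tune and workspace parameters of the imperative mx.nd.Convolution op, since the Gluon Conv3D block may not expose them directly):

import mxnet as mx

ctx = mx.gpu(0)
x = mx.nd.random.normal(shape=(300, 256, 2, 11, 11), ctx=ctx)
w = mx.nd.random.normal(shape=(512, 256, 3, 3, 3), ctx=ctx)
# cudnn_tune='limited_workspace' runs cudnnFind() but rejects any algo whose
# workspace exceeds the limit; workspace=2048 raises that limit to 2 GB (the unit is MB).
y = mx.nd.Convolution(data=x, weight=w, no_bias=True, num_filter=512,
                      kernel=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1),
                      cudnn_tune='limited_workspace', workspace=2048)
print(y.shape)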

There is currently no way to limit algos by workspace size without also running cudnnFind(). We could add this functionality in a backward-compatible way by adding a new supported value to MXNET_CUDNN_AUTOTUNE_DEFAULT:
- Values: 0, 1, 2, or 3 (default=1)
- The default value of cudnn auto tuning for convolution layers.
- Value of 0 means there is no auto tuning to pick the convolution algo
- Performance tests are run to pick the convolution algo when value is 1 or 2
- Value of 1 chooses the best algo in a limited workspace
- Value of 2 chooses the fastest algo whose memory requirements may be larger than the default workspace threshold
- Value of 3 means there is no auto tuning to pick the convolution algo, but the algo cannot have a workspace requirement greater than the limit.

There would be a locally set equivalent to this in the Convolution parameters:
- cudnn_tune='off' # use cudnnGet(), no workspace limit, even if set locally
- cudnn_tune='off_limited_workspace' # use cudnnGet() subject to 1GB or locally-set limit
- cudnn_tune='limited_workspace' # use cudnnFind() subject to 1GB or locally-set limit
- cudnn_tune='fastest' # use cudnnFind(), no workspace limit, even if set locally

While we're at it, I'm not fond of the compiled-in default workspace size of 1 GB. I'd suggest adding an environment variable:

MXNET_CUDNN_WORKSPACE_LIMIT_DEFAULT # If not set, then limit = 1024 (MB)

@Vikas-kum
Contributor Author

Great, thanks for the detailed explanation. Closing this.
