
Stacked RNNCell and FusedRNNCell may give inconsistent results with GPU context #16548

zixuanweeei opened this issue Oct 19, 2019 · 6 comments

@zixuanweeei
Contributor

Description

Some new test cases were added in #16420 to cover the fusion routine of the RNN operators in the unit tests. Since then, test_operator.py:test_rnnrelu_sym has failed intermittently several times on the online CI, on both the Unix-GPU MKLDNN+GPU and Unix-GPU NOMKLDNN+GPU pipelines. We have not yet identified the root cause of the flakiness, but we can reproduce the inconsistent results locally. Please see the details below.

Environment info (Required)

----------Python Info----------
Version      : 3.7.3
Compiler     : GCC 7.3.0
Build        : ('default', 'Mar 27 2019 22:11:17')
Arch         : ('64bit', '')
------------Pip Info-----------
Version      : 19.1.1
Directory    : /root/miniconda3/lib/python3.7/site-packages/pip
----------MXNet Info-----------
Version      : 1.6.0
Directory    : /root/dev/incubator-mxnet/python/mxnet
Commit hash file "/root/dev/incubator-mxnet/python/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
Library      : ['/root/dev/incubator-mxnet/lib/libmxnet.so', '/root/dev/incubator-mxnet/python/mxnet/../../lib/libmxnet.so']
Build features:
✔ CUDA
✔ CUDNN
✖ NCCL
✔ CUDA_RTC
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✔ CPU_SSE4_1
✔ CPU_SSE4_2
✖ CPU_SSE4A
✔ CPU_AVX
✖ CPU_AVX2
✔ OPENMP
✖ SSE
✔ F16C
✖ JEMALLOC
✖ BLAS_OPEN
✖ BLAS_ATLAS
✔ BLAS_MKL
✖ BLAS_APPLE
✖ LAPACK
✖ MKLDNN
✔ OPENCV
✖ CAFFE
✖ PROFILER
✖ DIST_KVSTORE
✖ CXX14
✖ INT64_TENSOR_SIZE
✖ SIGNAL_HANDLER
✖ DEBUG
✖ TVM_OP
----------System Info----------
Platform     : Linux-4.18.0-15-generic-x86_64-with-debian-buster-sid
system       : Linux
node         : d64ced67d422
release      : 4.18.0-15-generic
version      : #16~18.04.1-Ubuntu SMP Thu Feb 7 14:06:04 UTC 2019

Package used (Python/R/Scala/Julia):
I'm using Python Package

Build info (Required if built from source)

Using built-in specs.
COLLECT_GCC=gcc
COLLECT_LTO_WRAPPER=/usr/lib/gcc/x86_64-linux-gnu/7/lto-wrapper
OFFLOAD_TARGET_NAMES=nvptx-none
OFFLOAD_TARGET_DEFAULT=1
Target: x86_64-linux-gnu
Configured with: ../src/configure -v --with-pkgversion='Ubuntu 7.4.0-1ubuntu1~18.04.1' --with-bugurl=file:///usr/share/doc/gcc-7/README.Bugs --enable-languages=c,ada,c++,go,brig,d,fortran,objc,obj-c++ --prefix=/usr --with-gcc-major-version-only --program-suffix=-7 --program-prefix=x86_64-linux-gnu- --enable-shared --enable-linker-build-id --libexecdir=/usr/lib --without-included-gettext --enable-threads=posix --libdir=/usr/lib --enable-nls --with-sysroot=/ --enable-clocale=gnu --enable-libstdcxx-debug --enable-libstdcxx-time=yes --with-default-libstdcxx-abi=new --enable-gnu-unique-object --disable-vtable-verify --enable-libmpx --enable-plugin --enable-default-pie --with-system-zlib --with-target-system-zlib --enable-objc-gc=auto --enable-multiarch --disable-werror --with-arch-32=i686 --with-abi=m64 --with-multilib-list=m32,m64,mx32 --enable-multilib --with-tune=generic --enable-offload-targets=nvptx-none --without-cuda-driver --enable-checking=release --build=x86_64-linux-gnu --host=x86_64-linux-gnu --target=x86_64-linux-gnu
Thread model: posix
gcc version 7.4.0 (Ubuntu 7.4.0-1ubuntu1~18.04.1)

MXNet commit hash:
63fbfb1

Build config:

make -j10 USE_PROFILER=0 USE_CUDA=1 USE_CUDNN=1 USE_MKLDNN=0 USE_BLAS=mkl USE_INTEL_PATH=/opt/intel USE_CUDA_PATH=/usr/local/cuda-10.0 USE_CUDNN_PATH=/usr/local/cuda-10.0/lib64 test

Error Message:

/root/dev/incubator-mxnet/python/mxnet/rnn/rnn_cell.py:675: UserWarning: NTC layout detected. Consider using TNC for FusedRNNCell for faster speed
  warnings.warn("NTC layout detected. Consider using "
Traceback (most recent call last):
  File "rnn_relu_unidirectional.py", line 69, in <module>
    check_consistency(fused_cell, stacked_cell, seq_len, batch_size, input_dim, "write")
  File "rnn_relu_unidirectional.py", line 62, in check_consistency
    assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=rtol, atol=atol)
  File "/root/miniconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1501, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/root/miniconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 827, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.01, atol=0.0001

Mismatch: 0.0195%
Max absolute difference: 0.00012941
Max relative difference: 1.6247102
 x: array([[[ 2.050658e-03,  1.623387e-03,  1.420917e-03, ...,
         -5.661430e-05,  1.032020e-03,  1.353525e-03],
        [ 2.300234e-05,  1.259212e-03,  1.050305e-03, ...,...
 y: array([[[ 2.050658e-03,  1.623387e-03,  1.420917e-03, ...,
         -5.661434e-05,  1.032020e-03,  1.353525e-03],
        [ 2.300218e-05,  1.259212e-03,  1.050305e-03, ...,...

Minimum reproducible example

import mxnet as mx
import numpy as np
from numpy.testing import assert_allclose
from mxnet.test_utils import set_default_context, default_context


def sym_gen(seq_len, batch_size, input_dim, state_dim):
  fused = mx.rnn.FusedRNNCell(state_dim, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
  stack = mx.rnn.SequentialRNNCell()
  stack.add(mx.rnn.RNNCell(state_dim, activation='relu', prefix='l0_'))
  stack.add(mx.rnn.RNNCell(state_dim, activation='relu', prefix='l1_'))
  stack.add(mx.rnn.RNNCell(state_dim, activation='relu', prefix='l2_'))

  return fused, stack


def check_consistency(cell1, cell2, seq_len, batch_size, input_dim, grad_req, atol=1e-4, rtol=1e-2):
  dshape = (batch_size, seq_len, input_dim)
  data = mx.sym.Variable('data')

  Y1, _ = cell1.unroll(seq_len, data, layout='NTC', merge_outputs=True)
  mod1 = mx.mod.Module(Y1, label_names=None, context=default_context())
  mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req)

  Y2, _ = cell2.unroll(seq_len, data, layout='NTC', merge_outputs=True)
  mod2 = mx.mod.Module(Y2, label_names=None, context=default_context())
  mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req)

  mod1.init_params()
  _, auxs = mod1.get_params()
  # Load the pre-generated weights, convert them to NDArrays, then pack
  # them into the layout expected by each cell
  npzfile = np.load("./issue_array.npz")
  arrays = {name: mx.nd.array(npzfile[name]) for name in npzfile.files}
  args1 = cell1.pack_weights(arrays)
  mod1.set_params(args1, auxs)
  args2 = cell2.pack_weights(arrays)
  mod2.set_params(args2, auxs)

  x = mx.nd.array(np.load("./x.npz")["x"])
  batch = mx.io.DataBatch(data=[x])
  # check inference
  mod1.forward(batch, is_train=False)
  mod2.forward(batch, is_train=False)
  assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol)

  # check training
  mod1.forward(batch, is_train=True)
  mod2.forward(batch, is_train=True)
  assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol)

  dy_array = np.load("./dy.npz")["dy"]
  dy = mx.nd.array(dy_array)
  mod1.backward(out_grads=[dy])
  mod2.backward(out_grads=[dy])
  if grad_req == 'null' or (isinstance(grad_req, dict) and grad_req['data'] == 'null'):
    # No gradient was requested for the input, so no input grads are produced
    assert mod1.get_input_grads()[0] is None
    assert mod2.get_input_grads()[0] is None
  else:
    assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=rtol, atol=atol)


if __name__ == "__main__":
  set_default_context(mx.gpu(0))
  seq_len, batch_size, input_dim, state_dim = 5, 32, 32, 512
  fused_cell, stacked_cell = sym_gen(seq_len, batch_size, input_dim, state_dim)
  check_consistency(fused_cell, stacked_cell, seq_len, batch_size, input_dim, "write")

Steps to reproduce

  1. Download the input, weight, and gradient data from the link below. The archive contains three .npz files.
  2. Extract them into the directory containing the Python script above.
  3. Install mxnet-gpu and execute the script. A quick sanity check of the extracted files is sketched after the attachment link.

data.tar.gz
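
After extracting, the arrays in each archive can be listed with NumPy as a sanity check (a sketch; the file and key names are taken from the reproducer script above):

import numpy as np

# The reproducer expects these three files next to the script:
for fname in ("issue_array.npz", "x.npz", "dy.npz"):
  print(fname, list(np.load(fname).files))  # e.g., x.npz should contain the key "x"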

@zixuanweeei
Contributor Author

@pengzhao-intel @TaoLv

@pengzhao-intel
Contributor

How is the relative error computed here? It can cause flaky reports when the absolute values are very small.
The absolute difference is only 0.00012941, but the relative difference is larger than 1. Please check how it is calculated; maybe we need a more robust formulation of the relative difference.

@zixuanweeei
Contributor Author

Thanks for your reply. @pengzhao-intel

According to NumPy, the absolute and relative differences are computed as

error = abs(x - y)
max_abs_error = error.max()
max_rel_error = (error / abs(y)).max()

These two values appear in the error message. The pass criterion of assert_allclose(actual, desired, rtol, atol) is abs(actual - desired) <= atol + rtol * abs(desired), evaluated elementwise.
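
For illustration, here is a minimal NumPy sketch of the failure mode. The element values are hypothetical, chosen to mirror the reported maxima: when the desired value is close to zero, the relative difference can exceed 1 even though the absolute difference is only slightly above atol.

import numpy as np

# Hypothetical element values chosen to mirror the reported maxima;
# the actual failing elements in the gradient arrays are unknown.
x = np.array([1.353525e-03, 2.0906e-04])  # actual
y = np.array([1.353525e-03, 7.9650e-05])  # desired (near zero)

error = np.abs(x - y)
print(error.max())                # max absolute difference: ~1.2941e-04
print((error / np.abs(y)).max())  # max relative difference: ~1.6247
# assert_allclose fails here because
# 1.2941e-04 > atol + rtol * abs(y) = 1e-04 + 7.965e-07 ~= 1.0080e-04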

Below is an example from a failed test_rnnrelu_sym run on Windows-GPU:

*** Maximum errors for vector of size 20480:  rtol=0.01, atol=0.0001

  1: Error 1.192976  Location of error: (3, 0, 103), a=0.00010552, b=0.00022754

We can see that a and b take noticeably different values relative to their magnitude.
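
The reported Error value appears to be the violation ratio abs(a - b) / (atol + rtol * abs(b)); plugging in the logged values reproduces it, which suggests this failure is again driven by a near-zero reference value (a sketch under that assumption):

a, b = 0.00010552, 0.00022754
atol, rtol = 1e-4, 1e-2
print(abs(a - b) / (atol + rtol * abs(b)))  # ~1.193, matching "Error 1.192976"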

@ddavydenko
Contributor

@mxnet-label-bot add [Bug]

lanking520 added the Bug label on Oct 21, 2019
@samskalicky
Contributor

@lanking520 assign @apeforest
@zixuanweeei any update on this issue?

@zixuanweeei
Contributor Author

> @lanking520 assign @apeforest
> @zixuanweeei any update on this issue?

I just reverted the unit test for rnn_relu to its original form, given the urgency of the transition to mkldnn-v1.0. I have tested it locally: some flakiness still appears with this unit test, and it passes only sometimes. I am not familiar with the GPU platform, so this needs input from others with the relevant expertise.
