This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

added support for large tensors for Dropout operator and tests to verify support for more operators #16409

Merged: 3 commits, Oct 18, 2019
12 changes: 6 additions & 6 deletions src/operator/nn/dropout-inl.h
@@ -182,10 +182,10 @@ class DropoutOp {
    * \param input_data Input data to perform the dropout on
    * \param pkeep Dropout rate (keep when the generated random number is less than this value)
    */
-  MSHADOW_XINLINE static void Map(int id,
+  MSHADOW_XINLINE static void Map(index_t id,
                                   RandGenerator<xpu, DType> gen,
-                                  const int N,
-                                  const int step,
+                                  const index_t N,
+                                  const index_t step,
                                   DType *dropout_out,
                                   DType *mask_out,
                                   const DType *input_data,
@@ -199,10 +199,10 @@ class DropoutOp {
  };
  struct BernoulliKernel {
    /*! \brief Bernoulli kernel for generating mask */
-   MSHADOW_XINLINE static void Map(int id,
+   MSHADOW_XINLINE static void Map(index_t id,
                                    RandGenerator<xpu, DType> gen,
-                                   const int N,
-                                   const int step,
+                                   const index_t N,
+                                   const index_t step,
                                    DType *mask_out,
                                    const real_t pkeep) {
      RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
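The switch from int to index_t above widens the kernel's element index and count so Dropout can address tensors with more than 2^31 - 1 elements. A minimal Python sketch of how this might be exercised, assuming an MXNet build with 64-bit tensor size enabled (the USE_INT64_TENSOR_SIZE build flag), enough host memory, and LARGE_X/SMALL_Y constants mirroring the nightly large-tensor tests:

import mxnet as mx
import numpy as np
from mxnet import nd

LARGE_X = 100000000   # illustrative value, as in tests/nightly/test_large_array.py
SMALL_Y = 50

x = nd.ones((SMALL_Y, LARGE_X))                          # SMALL_Y*LARGE_X > 2**31 - 1 elements
y = nd.Dropout(x, p=0.5, mode='always', cudnn_off=True)  # runs the widened dropout kernels
nd.waitall()
assert y.shape == x.shape
# kept elements are scaled by 1/(1 - p), so each output value is either 0 or 2
assert set(np.unique(y[0][0:10].asnumpy())).issubset({0.0, 2.0})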
5 changes: 3 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -151,9 +151,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet:
   *new_oshape = mxnet::TShape(odim, 1);
   int bl = oshape.ndim() - lshape.ndim();
   int br = oshape.ndim() - rshape.ndim();
-  int j = 0, lprod = 1, rprod = 1, oprod = 1;
+  int j = 0;
+  index_t lprod = 1, rprod = 1, oprod = 1;
   for (int i = 0; i < oshape.ndim(); ++i) {
-    int l = 1, r = 1, o = oshape[i];
+    index_t l = 1, r = 1, o = oshape[i];
     if (i >= bl) l = lshape[i-bl];
     if (i >= br) r = rshape[i-br];
     if ((lprod != rprod || l != r) &&
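Promoting lprod, rprod, oprod (and the per-axis l, r, o) to index_t matters because it is the running product of dimensions that overflows, even when every individual dimension still fits in a 32-bit int. A small illustrative sketch, reusing the assumed nightly-test sizes from above:

from mxnet import nd

SMALL_Y, LARGE_X = 50, 100000000
INT32_MAX = 2**31 - 1
assert SMALL_Y < INT32_MAX and LARGE_X < INT32_MAX   # each dim fits in 32 bits
assert SMALL_Y * LARGE_X > INT32_MAX                 # their product (5e9) does not

# a broadcast over such a shape exercises this shape-compaction path
a = nd.ones((SMALL_Y, LARGE_X))
b = nd.ones((SMALL_Y, 1))
c = nd.broadcast_add(a, b)
assert c.shape == (SMALL_Y, LARGE_X)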
240 changes: 240 additions & 0 deletions tests/nightly/test_large_array.py
@@ -1199,6 +1199,246 @@ def test_full():
    assert a[-1][-1] == 3


def test_astype():
    x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
    y = x.astype('int32')
    assert y.dtype == np.int32
    assert y[-1][-1] == SMALL_Y-1


def test_cast():
    x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
    y = nd.cast(x, np.int32)
    assert y.dtype == np.int32
    assert y[-1][-1] == SMALL_Y-1


def test_repeat():
    x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X//2)
    y = nd.repeat(x, repeats=2, axis=1)
    assert y.shape == (SMALL_Y, LARGE_X)
    assert y[0][1] == 0
    assert y[-1][-1] == SMALL_Y-1
    x = create_2d_tensor(rows=SMALL_Y//2, columns=LARGE_X)
    y = nd.repeat(x, repeats=2, axis=0)
    assert y.shape == (SMALL_Y, LARGE_X)
    assert y[0][1] == 0
    assert y[-1][0] == SMALL_Y//2-1


def create_input_for_rounding_ops():
    # Creates a vector with values (-LARGE_X/2 .... -2, -1, 0, 1, 2, .... , LARGE_X/2-1)
    # then divides each element by 2 i.e. (-LARGE_X/4 .... -1, -0.5, 0, 0.5, 1, .... , LARGE_X/4-1)
    # and finally broadcasts it to shape (SMALL_Y, LARGE_X)
    inp = nd.arange(-LARGE_X//2, LARGE_X//2, dtype=np.float64).reshape(1, LARGE_X)
    inp = inp/2
    inp = nd.broadcast_to(inp, (SMALL_Y, LARGE_X))
    return inp


def assert_correctness_of_rounding_ops(output, mid, expected_vals):
    # verifies 5 values at the middle positions of the output row
    # i.e. mid-2, mid-1, mid, mid+1, mid+2
    output_idx_to_inspect = [mid-2, mid-1, mid, mid+1, mid+2]
    for i in range(len(output_idx_to_inspect)):
        assert output[1][output_idx_to_inspect[i]] == expected_vals[i]


# TODO(access2rohit): merge similar tests in large vector and array into one file.
def test_rounding_ops():
    x = create_input_for_rounding_ops()

    def check_ceil():
        y = nd.ceil(x)
        # expected output for middle 5 values after applying ceil()
        expected_output = [-1, 0, 0, 1, 1]
        assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)

    def check_fix():
        y = nd.fix(x)
        # expected output for middle 5 values after applying fix()
        expected_output = [-1, 0, 0, 0, 1]
        assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)

    def check_floor():
        y = nd.floor(x)
        # expected output for middle 5 values after applying floor()
        expected_output = [-1, -1, 0, 0, 1]
        assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)

    def check_rint():
        y = nd.rint(x)
        # expected output for middle 5 values after applying rint()
        expected_output = [-1, -1, 0, 0, 1]
        assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)

    def check_round():
        y = nd.round(x)
        # expected output for middle 5 values after applying round()
        expected_output = [-1, -1, 0, 1, 1]
        assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)

    def check_trunc():
        y = nd.trunc(x)
        # expected output for middle 5 values after applying trunc()
        expected_output = [-1, 0, 0, 0, 1]
        assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)

    check_ceil()
    check_fix()
    check_floor()
    check_rint()
    check_round()
    check_trunc()


def create_input_for_trigonometric_ops(vals):
    # Creates a large tensor of shape (LARGE_X*10, SMALL_Y//10) from vals using the broadcast_to operator
    inp = nd.array(vals).reshape(1, 5)
    inp = nd.broadcast_to(inp, (LARGE_X*10, SMALL_Y//10))
    return inp


def assert_correctness_of_trigonometric_ops(output, expected_vals):
    # verifies 5 values at positions (0, 1, -3, -2, -1) of the output
    output_idx_to_inspect = [0, 1, -3, -2, -1]
    for i in range(len(output_idx_to_inspect)):
        assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= 1e-3


def test_trigonometric_ops():
    def check_arcsin():
        x = create_input_for_trigonometric_ops([-1, -.707, 0, .707, 1])
        y = nd.arcsin(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying arcsin()
        expected_output = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_arccos():
        x = create_input_for_trigonometric_ops([-1, -.707, 0, .707, 1])
        y = nd.arccos(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying arccos()
        expected_output = [np.pi, 3*np.pi/4, np.pi/2, np.pi/4, 0]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_arctan():
        x = create_input_for_trigonometric_ops([-np.Inf, -1, 0, 1, np.Inf])
        y = nd.arctan(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying arctan()
        expected_output = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_sin():
        x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])
        y = nd.sin(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying sin()
        expected_output = [-1, -.707, 0, .707, 1]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_cos():
        x = create_input_for_trigonometric_ops([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi])
        y = nd.cos(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying cos()
        expected_output = [1, .707, 0, -.707, -1]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_tan():
        x = create_input_for_trigonometric_ops([-np.pi/6, -np.pi/4, 0, np.pi/4, np.pi/6])
        y = nd.tan(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying tan()
        expected_output = [-.577, -1, 0, 1, .577]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_radians():
        x = create_input_for_trigonometric_ops([0, 90, 180, 270, 360])
        y = nd.radians(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying radians()
        expected_output = [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    def check_degrees():
        x = create_input_for_trigonometric_ops([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi])
        y = nd.degrees(x)
        # expected output for indices=(0, 1, -3, -2, -1) after applying degrees()
        expected_output = [0, 90, 180, 270, 360]
        assert_correctness_of_trigonometric_ops(y, expected_output)

    check_arcsin()
    check_arccos()
    check_arctan()
    check_sin()
    check_cos()
    check_tan()
    check_radians()
    check_degrees()


def test_L2Normalization():
    x = nd.ones((2, LARGE_X*2))
    x[0] = 3
    x[1] = 4
    # Channel Mode
    z = x.reshape(1, 2, LARGE_X*2)
    y = nd.L2Normalization(z, mode='channel')
    assert y[0][0][0] == 0.6
    assert y[0][0][-1] == 0.6
    assert y[0][1][0] == 0.8
    assert y[0][1][-1] == 0.8
    # Instance Mode
    z = x.T
    y = nd.L2Normalization(z, mode='instance')
    assert y[0][0] == 0.6
    assert y[0][1] == 0.8
    assert y[-1][0] == 0.6
    assert y[-1][1] == 0.8
    # Spatial Mode
    z = z.reshape(1, 200000000, 2)
    y = nd.L2Normalization(z, mode='spatial')
    assert y[0][0][0] == 0.6
    assert y[0][0][1] == 0.8
    assert y[0][-1][0] == 0.6
    assert y[0][-1][1] == 0.8


def test_instance_norm():
    dtype = np.float32
    forward_check_eps = 1E-3
    axis = -1
    eps = 1E-5
    in_shape = (LARGE_X, 1, SMALL_Y)
    ctx = mx.cpu()

    # Implementation of instance normalization using numpy
    def npy_instance_norm(data, gamma, beta, axis, eps=1E-5):
        if axis < 0:
            axis += data.ndim
        broadcast_shape = [1 for _ in range(data.ndim)]
        broadcast_shape[axis] = data.shape[axis]
        mean = data.mean(axis=axis, keepdims=True).astype(dtype)
        var = data.var(axis=axis, keepdims=True).astype(dtype)
        std = np.sqrt(var + dtype(eps)).astype(dtype)
        out = gamma * (data - mean) / std + beta
        return out
    data = np.random.normal(0, 1, in_shape).astype(dtype)
    gamma = np.random.normal(0, 1, (1,)).astype(dtype)
    beta = np.random.normal(0, 1, (1,)).astype(dtype)
    data_s = mx.symbol.Variable('data')
    gamma_s = mx.symbol.Variable('gamma')
    beta_s = mx.symbol.Variable('beta')
    out_s = mx.symbol.InstanceNorm(data=data_s, gamma=gamma_s, beta=beta_s,
                                   eps=eps)
    exe = out_s.simple_bind(ctx, data=in_shape)
    exe.arg_dict['data'][:] = data
    exe.arg_dict['gamma'][:] = gamma
    exe.arg_dict['beta'][:] = beta
    out_nd = exe.forward()[0]
    # Calls the numpy implementation of instance norm and compares the outputs
    out = npy_instance_norm(data, gamma, beta, axis, eps)
    assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
                        forward_check_eps)


if __name__ == '__main__':
    import nose
    nose.runmodule()