diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index fee534315b77..37af908042c9 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -124,16 +124,22 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
   const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
+  if (req[0] == kNullOp) return;
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Transpose only supports kWriteTo, kNullOp and kAddTo";
+  mxnet::TShape axes;
   if (ndim_is_known(param.axes)) {
-    mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    axes = common::CanonicalizeAxes(param.axes);
   } else {
-    mxnet::TShape axes(inputs[0].ndim(), -1);
+    axes = mxnet::TShape(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
     }
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  }
+  if (req[0] == kAddTo) {
+    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  } else {
+    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
   }
 }
 
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index e496202a0b41..05b7e948e8a0 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -24,6 +24,7 @@
  */
 
 #include <vector>
+#include <set>
 #include "./np_matrix_op-inl.h"
 #include "../nn/concat-inl.h"
 
@@ -67,8 +68,13 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape ret(ndim, -1);
 
   if (ndim_is_known(param.axes)) {
-    CHECK_EQ(ndim, param.axes.ndim());
+    CHECK_EQ(ndim, param.axes.ndim())
+        << "The number of axes does not match the dimension of the tensor. axes = "
+        << param.axes << ", input tensor shape = " << shp;
     mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
+    std::set<dim_t> axes_set(axes.begin(), axes.end());
+    CHECK_EQ(axes_set.size(), axes.ndim()) << "Repeated axis in transpose. param.axes = "
+                                           << param.axes;
     if (ndim_is_known(shp)) {
       for (int i = 0; i < ndim; ++i) {
         ret[i] = shp[axes[i]];
@@ -117,9 +123,9 @@ NNVM_REGISTER_OP(_np_transpose)
         }
         std::ostringstream os;
         os << axes;
-        return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}});
+        return MakeNonlossGradNode("_np_transpose", n, ograds, {}, {{"axes", os.str()}});
       } else {
-        return MakeNonlossGradNode("transpose", n, ograds, {},
+        return MakeNonlossGradNode("_np_transpose", n, ograds, {},
                                    std::unordered_map<std::string, std::string>());
       }
     })
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 0fee2a26c0ed..4bd059ae81df 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
  * \param out output tensor
  * \param row shape of dim 0 of input
  * \param col shape of dim 1 of input
+ * \tparam DType Data type
+ * \tparam is_addto Whether to add the transpose result to the output
  */
-template<typename DType>
+template<typename DType, bool is_addto = false>
 MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
   // ensure cache line hits and prevent cache miss for any configuration
   // L1 cache size to be utilized = 32kb = 2^15
@@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
   // Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
   // blocksize * blocksize * num_threads = cache_size / dtype_size
   // Instead of explicit unroll, let compiler figure out optimal unroll factor
-  index_t blocksize = 32;
+  const index_t blocksize = 32;
 
   // collapse 2 parallelizes 2 for loops
   // inner 2 for loops aren't parallelized to prevent cache miss
@@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
         // transpose the block
         for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
           for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
-            out[a * row + b] = in[b * col + a];
+            if (!is_addto) {
+              out[a * row + b] = in[b * col + a];
+            } else {
+              out[a * row + b] += in[b * col + a];
+            }
           }
         }
       }
     }
   }
 }
 
-template<typename xpu>
+inline bool IsIdentityTranspose(const TShape& axes) {
+  for (dim_t i = 0; i < axes.ndim(); i++) {
+    if (axes[i] != i) return false;
+  }
+  return true;
+}
+
+template<typename xpu, bool is_addto = false>
 void TransposeImpl(RunContext ctx,
                    const TBlob& src,
                    const TBlob& ret,
@@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx,
   // Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3).
   if (isPseudo2DTranspose(axes)) {
     MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
-      transpose_pseudo2D<DType>(ret, src, axes, s);
+      transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
     });
     return;
   }
 #endif
+  // Special handle the identity case
+  if (IsIdentityTranspose(axes)) {
+    MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
+      Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.Size()), s);
+      Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret.Size()), s);
+      if (!is_addto) {
+        // Use memcpy to accelerate the speed
+        Copy(out, in, s);
+      } else {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch(
+            s, ret.Size(), out.dptr_, in.dptr_);
+      }
+    });
+    return;
+  }
+  // Handle the general transpose case
   MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
     switch (axes.ndim()) {
-      case 0: {
-        Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
-        Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
-        Copy(out, in, s);
-        break;
-      }
-      case 1: {
-        Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
-        Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s);
-        Copy(out, in, s);
-        break;
-      }
       case 2: {
-        mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
-        mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);
-
-        if (axes[0] == 1 && axes[1] == 0) {
-          if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
-            Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
-          } else {
-            out = in.T();
-          }
+        Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
+        Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s);
+        if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
+          Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
         } else {
-          Copy(out, in, s);
+          LOG(FATAL) << "Not Implemented. We should never reach here because the 2D case "
+                        "in GPU has been covered by transpose_pseudo2D."
+                        " Report an issue in Github.";
         }
         break;
       }
       case 3: {
         Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s);
         Tensor<xpu, 3, DType> out = ret.get<xpu, 3, DType>(s);
-        out = transpose(in, axes.get<3>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<3>());
+        } else {
+          out += transpose(in, axes.get<3>());
+        }
         break;
       }
       case 4: {
         Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s);
         Tensor<xpu, 4, DType> out = ret.get<xpu, 4, DType>(s);
-        out = transpose(in, axes.get<4>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<4>());
+        } else {
+          out += transpose(in, axes.get<4>());
+        }
         break;
       }
       case 5: {
         Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s);
         Tensor<xpu, 5, DType> out = ret.get<xpu, 5, DType>(s);
-        out = transpose(in, axes.get<5>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<5>());
+        } else {
+          out += transpose(in, axes.get<5>());
+        }
         break;
       }
       case 6: {
         Tensor<xpu, 6, DType> in = src.get<xpu, 6, DType>(s);
         Tensor<xpu, 6, DType> out = ret.get<xpu, 6, DType>(s);
-        out = transpose(in, axes.get<6>());
+        if (!is_addto) {
+          out = transpose(in, axes.get<6>());
+        } else {
+          out += transpose(in, axes.get<6>());
+        }
         break;
       }
       default:
@@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs,
     return;
   }
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Transpose only supports kNullOp, kWriteTo and kAddTo";
+  mxnet::TShape axes;
   if (param.axes.ndim() == 0) {
-    mxnet::TShape axes(inputs[0].ndim(), -1);
+    axes = mxnet::TShape(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
     }
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
   } else {
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes);
+    axes = common::CanonicalizeAxes(param.axes);
+  }
+  if (req[0] == kAddTo) {
+    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  } else {
+    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
   }
 }
 
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 5b702fbaa2d6..15b954f11c1d 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -283,11 +283,12 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
     return;
   }
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo) <<
+      "Transpose only supports kNullOp, kWriteTo and kAddTo";
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
 
-  if (SupportMKLDNNTranspose(param, inputs[0])) {
+  if (SupportMKLDNNTranspose(param, inputs[0]) && req[0] == kWriteTo) {
     MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
diff --git a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
index 5b7cf04daef4..b3ca9fbfa0c9 100644
--- a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
+++ b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
@@ -39,22 +39,31 @@ namespace mxnet {
 namespace op {
 namespace cuda {
 
-
-template<typename DType, typename CType>
+/*!
+ * \brief The `transpose_pseudo2D` kernel based on the chosen vectorized type. It transposes an
+ *        array of shape (k, m, n) to (k, n, m).
+ * \param out Pointer to output memory.
+ * \param inp Pointer to input memory.
+ * \param m First of tensor dimensions.
+ * \param n Second of tensor dimensions.
+ * \param nIterY The number of iterations in the y-dim of the thread to cover all rows. (1-->m)
+ * \param nIterZ The number of iterations in the z-dim of the thread to cover the leading dim. (1-->k)
+ * \tparam DType Data type
+ * \tparam CType The type to load the data.
+ * \tparam is_addto Whether to perform out += transpose(data) or out = transpose(data)
+ */
+template<typename DType, typename CType, bool is_addto>
 __global__ void transpose_pseudo2D(DType* out, DType* inp,
                                    const index_t m, const index_t n,
                                    const index_t nIterY, const index_t nIterZ) {
-  const index_t TSR = sizeof(CType)/sizeof(DType);  // TypeSizeRatio
+  // Calculate the TypeSizeRatio
+  const index_t TSR = sizeof(CType) / sizeof(DType) > 0 ?
+                      sizeof(CType) / sizeof(DType) : 1;
   const index_t chunked_n = n/TSR;
   const index_t chunked_m = m/TSR;
 
-  union transp_t {
-    CType valChunk;
-    DType values[TSR];
-  };
-
-  __shared__ DType d_shm[1024*TSR*TSR];
-  CType* c_shm = reinterpret_cast<CType*>(d_shm);
+  extern __shared__ char buf[];
+  DType* d_shm = reinterpret_cast<DType*>(buf);
+  CType* c_shm = reinterpret_cast<CType*>(buf);
   CType* cInp = reinterpret_cast<CType*>(inp);
   CType* cOut = reinterpret_cast<CType*>(out);
 
@@ -78,23 +87,34 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
       }
       __syncthreads();
 
-      // read from shared to registers
-      transp_t tmp[TSR];
+      // read from shared to local registers
+      CType tmp[TSR];
       #pragma unroll
       for (index_t i = 0; i < TSR; i++) {
+        DType* tmp_dptr = reinterpret_cast<DType*>(&tmp[i]);
         #pragma unroll
         for (int j = 0; j < TSR; j++) {
           index_t shmIdx = (TSR*threadIdx.y + j)*blockDim.x*TSR + TSR*threadIdx.x + i;
-          tmp[i].values[j] = d_shm[shmIdx];
+          tmp_dptr[j] = d_shm[shmIdx];
        }
       }
       __syncthreads();
 
       // write back to global output
-      offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m + blockIdx_y*blockDim.y;
+      offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m +
+               blockIdx_y*blockDim.y;
       #pragma unroll
       for (index_t i = 0; i < TSR; i++) {
-        cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i].valChunk;
+        if (is_addto) {
+          DType* tmp_dptr = reinterpret_cast<DType*>(&tmp[i]);
+          #pragma unroll
+          for (int j = 0; j < TSR; j++) {
+            out[TSR * (offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y) + j]
+              += tmp_dptr[j];
+          }
+        } else {
+          cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i];
+        }
       }
     }
   }
@@ -107,7 +127,6 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
 /*!
  * \brief Calls proper version of kernel `transpose_pseudo2D`
  *        basing on chosen type sizes.
- * \param dTypeSize Size of data type.
 * \param cTypeSize Size of type that should be use to copy.
 * \param grid Grid dimensions for the kernel.
 * \param block Block dimensions for the kernel.
@@ -116,92 +135,39 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
 * \param inp Pointer to input memory.
 * \param m First of tensor dimensions.
 * \param n Second of tensor dimensions.
+ * \tparam DType Data type
+ * \tparam is_addto Whether to add the transpose result to the output tensor.
 */
-inline void call_transpose_pseudo2D(index_t dTypeSize, index_t cTypeSize,
-                                    dim3 grid, dim3 block, cudaStream_t stream,
-                                    void* out, void* inp, const index_t m, const index_t n,
-                                    const index_t nIterY, const index_t nIterZ) {
-  switch (dTypeSize) {
-    case (1): {
-      uint8_t* d_outPtr = reinterpret_cast<uint8_t*>(out);
-      uint8_t* d_inpPtr = reinterpret_cast<uint8_t*>(inp);
-      switch (cTypeSize) {
-        case (1):
-          cuda::transpose_pseudo2D<uint8_t, uint8_t><<<grid, block, 0, stream>>>
+template<typename DType, bool is_addto>
+inline void call_transpose_pseudo2D(index_t cTypeSize,
+                                    dim3 grid, dim3 block, cudaStream_t stream,
+                                    DType* d_outPtr, DType* d_inpPtr,
+                                    const index_t m, const index_t n,
+                                    const index_t nIterY, const index_t nIterZ) {
+  const int nshared = 1024 * cTypeSize / sizeof(DType) * cTypeSize;
+  switch (cTypeSize) {
+    case (1):
+      cuda::transpose_pseudo2D<DType, uint8_t, is_addto><<<grid, block, nshared, stream>>>
        (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-        case (2):
-          cuda::transpose_pseudo2D<uint8_t, uint16_t><<<grid, block, 0, stream>>>
+    case (2):
+      cuda::transpose_pseudo2D<DType, uint16_t, is_addto><<<grid, block, nshared, stream>>>
        (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-        case (4):
-          cuda::transpose_pseudo2D<uint8_t, uint32_t><<<grid, block, 0, stream>>>
+    case (4):
+      cuda::transpose_pseudo2D<DType, uint32_t, is_addto><<<grid, block, nshared, stream>>>
        (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-        case (8):
-          // case guarded against in function getBestCopyTypeSize
-          LOG(FATAL) << "cuda::transpose_pseudo2D would take too much shared memory";
-        default:
-          LOG(FATAL) << "Unsupported type combination";
-      }
-      break;
-    }
-    case (2): {
-      uint16_t* d_outPtr = reinterpret_cast<uint16_t*>(out);
-      uint16_t* d_inpPtr = reinterpret_cast<uint16_t*>(inp);
-      switch (cTypeSize) {
-        case (2):
-          cuda::transpose_pseudo2D<uint16_t, uint16_t><<<grid, block, 0, stream>>>
+    case (8):
+      cuda::transpose_pseudo2D<DType, uint64_t, is_addto><<<grid, block, nshared, stream>>>
        (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-        case (4):
-          cuda::transpose_pseudo2D<uint16_t, uint32_t><<<grid, block, 0, stream>>>
-            (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-          break;
-        case (8):
-          cuda::transpose_pseudo2D<uint16_t, uint64_t><<<grid, block, 0, stream>>>
-            (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-          break;
-        default:
-          LOG(FATAL) << "Unsupported type combination";
-      }
-      break;
-    }
-    case (4): {
-      uint32_t* d_outPtr = reinterpret_cast<uint32_t*>(out);
-      uint32_t* d_inpPtr = reinterpret_cast<uint32_t*>(inp);
-      switch (cTypeSize) {
-        case (4):
-          cuda::transpose_pseudo2D<uint32_t, uint32_t><<<grid, block, 0, stream>>>
-            (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-          break;
-        case (8):
-          cuda::transpose_pseudo2D<uint32_t, uint64_t><<<grid, block, 0, stream>>>
-            (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-          break;
-        default:
-          LOG(FATAL) << "Unsupported type combination";
-      }
-      break;
-    }
-    case (8): {
-      uint64_t* d_outPtr = reinterpret_cast<uint64_t*>(out);
-      uint64_t* d_inpPtr = reinterpret_cast<uint64_t*>(inp);
-      switch (cTypeSize) {
-        case (8):
-          cuda::transpose_pseudo2D<uint64_t, uint64_t><<<grid, block, 0, stream>>>
-            (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-          break;
-        default:
-          LOG(FATAL) << "Unsupported type combination";
-      }
-      break;
-    }
-    default:
-      LOG(FATAL) << "Unsupported type combination";
+    default:
+      LOG(FATAL) << "Unsupported type combination. " << "Copy type size = " << cTypeSize;
   }
   auto cuErr = cudaPeekAtLastError();
-  CHECK_EQ(cuErr, cudaSuccess) << "Transpose kernel failure: " << cudaGetErrorString(cuErr) << ". "
+  CHECK_EQ(cuErr, cudaSuccess) << "TransposePseudo2D kernel failure: "
+                               << cudaGetErrorString(cuErr) << ". "
     << "block: (" << block.x << "," << block.y << "," << block.z << ")"
     << " grid: (" << grid.x << "," << grid.y << "," << grid.z << ")";
 }
@@ -225,7 +191,6 @@ inline bool isPseudo2DTranspose(const TShape& params) {
   return n_swpDims == 2;
 }
 
-
 struct pseudo2DSizes {
   index_t leadDimS;
   index_t M;
@@ -306,15 +271,14 @@ inline std::pair<dim3, dim3> calculateKernelParams(pseudo2DSizes sizes, const in
 * \param outBlob Tensor blob to store result.
 * \param inpBlob Tensor blob with input data.
 * \param params Parameters (axes) of the transpose.
+ * \param is_addto Whether to add the transpose result to the outBlob
 * \param s Pointer to GPU stream.
 */
-template<typename DType>
+template<typename DType, bool is_addto>
 void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob,
                         const TShape& params, mshadow::Stream<gpu>* s) {
   const TShape& shape = inpBlob.shape_;
   CHECK_EQ(shape.ndim(), params.ndim());
-  auto ndim = params.ndim();
-
   auto sizes = getPackedTransposeDimensions(shape, params);
 
   index_t cTypeSize = getBestCopyTypeSize(sizeof(DType), sizes.M, sizes.N);
@@ -337,8 +301,10 @@ void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob,
   }
 
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
-  call_transpose_pseudo2D(sizeof(DType), cTypeSize, grid, block, stream,
-      outBlob.dptr_, inpBlob.dptr_, sizes.M, sizes.N, nIterY, nIterZ);
+  call_transpose_pseudo2D<DType, is_addto>
+      (cTypeSize, grid, block, stream,
+       outBlob.dptr<DType>(), inpBlob.dptr<DType>(),
+       sizes.M, sizes.N, nIterY, nIterZ);
 }
 
 }  // namespace op
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 9b7f7036bcda..2a44385b8a6d 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1304,7 +1304,9 @@ def np_transpose_grad(out_shape, dtype, axes=None):
     if axes is None or axes == ():
         return _np.transpose(ograd, axes)
     np_axes = _np.array(list(axes))
-    return _np.transpose(ograd, tuple(list(_np.argsort(np_axes))))
+    transpose_axes = _np.zeros_like(np_axes)
+    transpose_axes[np_axes] = _np.arange(len(np_axes))
+    return _np.transpose(ograd, tuple(list(transpose_axes)))
 
 class TestTranspose(HybridBlock):
     def __init__(self, axes=None):
@@ -1313,45 +1315,57 @@ def __init__(self, axes=None):
         def hybrid_forward(self, F, a):
             return F.np.transpose(a, self.axes)
 
+    test_workloads = [[(), [(), None]],
+                      [(2,), [(0,), None]],
+                      [(0, 2), [(0, 1), (1, 0)]],
+                      [(5, 10), [(0, 1), (1, 0), None]],
+                      [(8, 2, 3), [(2, 0, 1), (0, 2, 1), (0, 1, 2), (2, 1, 0), (-1, 1, 0), None]],
+                      [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
+                      [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
+                      [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
+                      [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]]]
     for hybridize in [True, False]:
-        for dtype in [_np.int32, _np.float32]:
-            for ndim in range(7):
-                shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
-                axeses = [None]
-                if ndim == 0:
-                    axeses += [()]
-                else:
-                    axes = [i for i in range(ndim)]
-                    axeses.append(tuple(axes))
-                    random.shuffle(axes)
-                    axeses.append(tuple(axes))
-                    axeses.append([i - len(axes) for i in axes])
-                for axes in axeses:
-                    test_trans = TestTranspose(axes)
-                    if hybridize:
-                        test_trans.hybridize()
-                    x = rand_ndarray(shape).as_np_ndarray()
-                    x = x.astype(dtype)
-                    x.attach_grad()
-                    np_out = _np.transpose(x.asnumpy(), axes)
-                    with mx.autograd.record():
-                        mx_out = test_trans(x)
-                    assert mx_out.shape == np_out.shape
-                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
-                    mx_out.backward()
-                    np_backward = np_transpose_grad(np_out.shape, dtype, axes)
-                    assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
-
-                    mx_out = x.transpose(axes)
-                    np_out = x.asnumpy().transpose(axes)
-                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+        for dtype in [_np.float32, _np.float16, _np.int32]:
+            for data_shape, axes_workload in test_workloads:
+                for axes in axes_workload:
+                    for grad_req in ['write', 'add']:
+                        test_trans = TestTranspose(axes)
+                        if hybridize:
+                            test_trans.hybridize()
+                        x = np.random.normal(0, 1, data_shape).astype(dtype)
+                        x = x.astype(dtype)
+                        x.attach_grad(grad_req=grad_req)
+                        if grad_req == 'add':
+                            x.grad[()] = np.random.normal(0, 1, x.grad.shape).astype(x.grad.dtype)
+                            x_grad_np = x.grad.asnumpy()
+                        np_out = _np.transpose(x.asnumpy(), axes)
+                        with mx.autograd.record():
+                            mx_out = test_trans(x)
+                        assert mx_out.shape == np_out.shape
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+                        mx_out.backward()
+                        np_backward = np_transpose_grad(np_out.shape, dtype, axes)
+                        if grad_req == 'add':
+                            assert_almost_equal(x.grad.asnumpy(), np_backward + x_grad_np,
+                                                rtol=1e-3, atol=1e-5, use_broadcast=False)
+                        else:
+                            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
-                    if isinstance(axes, (list, tuple)):
-                        mx_out = x.transpose(*axes)
-                        np_out = x.asnumpy().transpose(*axes)
+                        mx_out = x.transpose(axes)
+                        np_out = x.asnumpy().transpose(axes)
                         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
+                        if isinstance(axes, (list, tuple)):
+                            mx_out = x.transpose(*axes)
+                            np_out = x.asnumpy().transpose(*axes)
+                            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+    # Test for error raising
+    dat = np.random.normal(0, 1, (3, 4, 5), dtype=np.float32)
+    assert_raises(MXNetError, lambda: dat.transpose((0, 0, 1)))
+    assert_raises(MXNetError, lambda: dat.transpose((0, 1, 3)))
+
+
 @with_seed()
 @use_np
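Note on the gradient helper exercised by the test above: the backward pass of transpose(axes) is a transpose by the inverse permutation of axes, and the scatter-based construction used in np_transpose_grad is equivalent to the _np.argsort call it replaces for any valid permutation. A minimal, self-contained NumPy sketch of that equivalence follows; the shapes, the example permutation, and the helper name inverse_permutation are illustrative only and are not part of the patch.

import numpy as np

def inverse_permutation(axes):
    # Scatter-based inverse, mirroring np_transpose_grad: if output axis i
    # came from input axis axes[i], then axes[i] must map back to position i.
    axes = np.array(axes)
    inv = np.zeros_like(axes)
    inv[axes] = np.arange(len(axes))
    return tuple(inv)

axes = (0, 2, 3, 1)                      # example permutation
assert inverse_permutation(axes) == tuple(np.argsort(np.array(axes)))

x = np.random.normal(size=(2, 3, 4, 5))  # example shape
y = np.transpose(x, axes)
# Transposing by the inverse permutation restores the original layout,
# which is what the transpose gradient applies to the incoming ograd.
assert np.array_equal(np.transpose(y, inverse_permutation(axes)), x)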