From 8727cae68fd732dc352be6d1d587f6512252d3b4 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 25 Apr 2018 22:42:29 -0700 Subject: [PATCH] [MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU (#10371) * add support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU * add unit test for new op and forward_stype_hint parameter to dot * update documentation for dot * address code reviews * fix flaky test_gluon:test_lambda through loosening the atol * switch dot(dns, csr) case to a deterministic algorithm with unit test for determinism * address code reviews and add backward --- src/operator/tensor/dot-inl.cuh | 245 ++++++++++++++++++ src/operator/tensor/dot-inl.h | 108 ++++++-- src/operator/tensor/dot.cc | 14 +- src/operator/tensor/util/tensor_util-inl.cuh | 16 ++ tests/python/unittest/test_gluon.py | 4 +- tests/python/unittest/test_sparse_operator.py | 96 ++++++- 6 files changed, 441 insertions(+), 42 deletions(-) diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index c546c4351a28..86df5801c73c 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -27,6 +27,8 @@ #include #include +#include "./init_op.h" +#include "./sort_op.h" #include "./util/tensor_util-inl.h" #include "./util/tensor_util-inl.cuh" @@ -442,6 +444,99 @@ struct DotCsrRspDnsScalarKernel { } }; +/*! + * \brief GPU Kernel to scatter row id to corresponding entries + * \param tid global thread id + * \param csr_indptr indptr array of csr + * \param csr_rows array of row id of csr elements + * \param num_rows total number of rows in csr matrix + * Parallelization by output elements: 1 thread/row + */ +struct CsrRowScatterKernel { + template + __device__ __forceinline__ static void Map(int tid, + const CType* csr_indptr, + CType* csr_rows, + const nnvm::dim_t num_rows) { + if (tid < num_rows) { + for (CType i = csr_indptr[tid]; i < csr_indptr[tid+1]; ++i) { + csr_rows[i] = tid; + } + } + } +}; + +struct CscDataIndicesKernel { + /*! + * \brief + * \param tid global thread id + * \param lhs_data lhs dense matrix data + * \param rhs_data csr matrix data + * \param rhs_indices csr matrix column indices + * \param rhs_indptr csr matrix row pointer + * \param out output matrix data + * \param lhs_num_cols lhs dns matrix number of columns + * \param out_num_rows output dns matrix number of rows + * \param out_num_cols output dns matrix number of columns + */ + template + __device__ __forceinline__ static void Map(int tid, + const IType* original_idx_ptr, + const DType* csr_data_ptr, + const CType* csr_rows_ptr, + DType* csc_data_ptr, + IType* csc_indices_ptr, + const nnvm::dim_t nnz) { + using nnvm::dim_t; + if (tid < nnz) { + const IType origin = original_idx_ptr[tid]; + csc_data_ptr[tid] = csr_data_ptr[origin]; + csc_indices_ptr[tid] = csr_rows_ptr[origin]; + } + } +}; + +/*! + * \brief GPU Kernel of dot(dns, csr.T) = dns + * Parallelization by output elements: 1 thread/element + */ +struct DotDnsCsrTransDnsKernel { + /*! 
+ * \brief + * \param tid global thread id + * \param lhs_data lhs dense matrix data + * \param rhs_data csr matrix data + * \param rhs_indices csr matrix column indices + * \param rhs_indptr csr matrix row pointer + * \param out output matrix data + * \param lhs_num_cols lhs dns matrix number of columns + * \param out_num_rows output dns matrix number of rows + * \param out_num_cols output dns matrix number of columns + */ + template + __device__ __forceinline__ static void Map(int tid, + const DType* lhs_data, + const DType* rhs_data, + const IType* rhs_indices, + const CType* rhs_indptr, + DType* out, + const nnvm::dim_t lhs_num_cols, + const nnvm::dim_t out_num_rows, + const nnvm::dim_t out_num_cols) { + using nnvm::dim_t; + if (tid < out_num_rows*out_num_cols) { + const dim_t i = static_cast(tid) % out_num_rows; // i = row this thread computes + const dim_t k = static_cast(tid) / out_num_rows; // k = col this thread computes + // Compute inner product of i-th row and k-th col + DType sum = 0; + for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) { + sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id]; + } + out[i * out_num_cols + k] = sum; + } + } +}; + /*! * \brief GPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2 */ @@ -895,6 +990,156 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx, }); } +// Returns integer log2(a) rounded up +inline int log2i(size_t a) { + int k = 1; + while (a >>= 1) k++; + return k; +} + +/* + * \brief GPU Impl of dot(dns, csr) = csr + */ +inline void DotDnsCsrCsrImpl(const OpContext& ctx, const gpu& gpu_dev, + const TBlob& lhs, const NDArray& rhs, + const OpReqType req, NDArray* ret) { + LOG(FATAL) << "dot(dense, csr) = csr is not implemented on GPU"; +} + +/* + * \brief GPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns + */ +inline void DotDnsCsrDnsImpl(const OpContext& ctx, const gpu& gpu_dev, + const TBlob& dns, const NDArray& rhs, + const OpReqType req, NDArray* ret, + const bool transpose_b) { + if (req == kNullOp) { + return; + } + CHECK_EQ(req, kWriteTo); + CHECK_EQ(rhs.storage_type(), kCSRStorage); + + using namespace mshadow; + using namespace mshadow::expr; + using nnvm::dim_t; + + /* Initialize data structures */ + mshadow::Stream* s = ctx.get_stream(); + TBlob csr_data = rhs.data(); + TBlob csr_indices = rhs.aux_data(csr::kIdx); + TBlob csr_indptr = rhs.aux_data(csr::kIndPtr); + if (!rhs.storage_initialized()) { + FillZerosCsrImpl(s, *ret); + return; + } + + MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // colidx type + const nnvm::dim_t out_num_rows = ret->shape()[0]; + const nnvm::dim_t out_num_cols = ret->shape()[1]; + // if dot(dense, csr) = dns, transform to csc first + if (!transpose_b) { + const nnvm::dim_t num_csr_rows = rhs.shape()[0]; + const nnvm::dim_t num_csr_cols = rhs.shape()[1]; + const nnvm::dim_t num_dns_rows = dns.shape_[0]; + const nnvm::dim_t nnz = rhs.storage_shape().Size(); + + IType* original_idx_ptr = nullptr; + IType* csc_indices_ptr = nullptr; + IType* csc_cols_ptr = nullptr; + CType* csr_rows_ptr = nullptr; + CType* csc_indptr_ptr = nullptr; + DType* csc_data_ptr = nullptr; + char* temp_storage_ptr = nullptr; + size_t original_idx_bytes = nnz*sizeof(IType); + size_t csc_indices_bytes = nnz*sizeof(IType); + size_t csc_cols_bytes = nnz*sizeof(IType); + size_t csr_rows_bytes = nnz*sizeof(CType); + 
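          // Note (annotation, not part of the original patch): all of the buffers sized here
          // are carved out of the single workspace allocation requested below — the original
          // element indices, the CSC indices/data built from the CSR input, a sortable copy of
          // the CSR column indices, the scattered row ids, the CSC indptr, and temporary
          // storage for cub's ExclusiveSum and the SortByKey call.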
size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType); + size_t csc_data_bytes = nnz*sizeof(DType); + size_t scan_temp_storage_bytes = 0; + size_t temp_storage_bytes = SortByKeyWorkspaceSize(nnz); + IType* csr_indices_ptr = csr_indices.dptr(); + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + scan_temp_storage_bytes, + csc_indptr_ptr, + csc_indptr_ptr, + num_csr_cols+1, + mshadow::Stream::GetStream(s)); + temp_storage_bytes = std::max(temp_storage_bytes, scan_temp_storage_bytes); + temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes % sizeof(dim_t)); + size_t total_workspace_bytes = + original_idx_bytes + csc_indices_bytes + csc_cols_bytes + csr_rows_bytes + + csc_indptr_bytes + csc_data_bytes + temp_storage_bytes; + total_workspace_bytes += (sizeof(IType) - total_workspace_bytes % sizeof(IType)); + Tensor workspace = ctx.requested[0] + .get_space_typed(Shape1(total_workspace_bytes), s); + original_idx_ptr = reinterpret_cast(workspace.dptr_); + csc_indices_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + csc_cols_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + csc_indices_bytes); + csr_rows_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + csc_indices_bytes + csc_cols_bytes); + csc_indptr_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + csc_indices_bytes + csc_cols_bytes + + csr_rows_bytes); + temp_storage_ptr = workspace.dptr_ + original_idx_bytes + csc_indices_bytes + + csc_cols_bytes + csr_rows_bytes + csc_indptr_bytes; + csc_data_ptr = reinterpret_cast( + workspace.dptr_ + total_workspace_bytes - csc_data_bytes); + + // Fill original_idx + mxnet_op::Kernel::Launch( + s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr); + // Fill csc_cols with copy of csr_indices + mxnet_op::Kernel, gpu>::Launch( + s, nnz, csc_cols_ptr, csr_indices_ptr); + // Allocate the tensors needed for SortByKey + Tensor original_idx(original_idx_ptr, Shape1(nnz), s); + Tensor csc_cols(csc_cols_ptr, Shape1(nnz), s); + Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + + int num_bits = log2i(num_csr_cols - 1); + SortByKey(csc_cols, original_idx, true, &temp_storage, 0, num_bits); + + // Scatter csr indptr to row id + mxnet_op::Kernel::Launch( + s, num_csr_rows, csr_indptr.dptr(), csr_rows_ptr, num_csr_rows); + // Reset indptr to zero + mxnet_op::Kernel::Launch(s, num_csr_cols+1, csc_indptr_ptr); + // Histogram on the sorted cols + mxnet_op::Kernel::Launch( + s, nnz, csc_indptr_ptr, csc_cols_ptr, nnz); + // Scan the bin counts for every column to get csc_indptr + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + temp_storage_bytes, + csc_indptr_ptr, + csc_indptr_ptr, + num_csr_cols+1, + mshadow::Stream::GetStream(s)); + // Assign data to csc matrix arrays + mxnet_op::Kernel::Launch( + s, nnz, original_idx_ptr, csr_data.dptr(), csr_rows_ptr, csc_data_ptr, + csc_indices_ptr, nnz); + + mxnet_op::Kernel::Launch( + s, out_num_rows * out_num_cols, dns.dptr(), + csc_data_ptr, csc_indices_ptr, csc_indptr_ptr, + ret->data().dptr(), dns.shape_[1], + out_num_rows, out_num_cols); + } else { + mxnet_op::Kernel::Launch( + s, out_num_rows * out_num_cols, dns.dptr(), + csr_data.dptr(), csr_indices.dptr(), + csr_indptr.dptr(), ret->data().dptr(), + dns.shape_[1], out_num_rows, out_num_cols); + } + }); + }); + }); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 83571d9e4d2c..2c9a483567f8 100644 --- a/src/operator/tensor/dot-inl.h +++ 
b/src/operator/tensor/dot-inl.h
@@ -45,6 +45,7 @@ namespace op {
 struct DotParam : public dmlc::Parameter<DotParam> {
   bool transpose_a;
   bool transpose_b;
+  dmlc::optional<int> forward_stype;
   DMLC_DECLARE_PARAMETER(DotParam) {
     DMLC_DECLARE_FIELD(transpose_a)
       .describe("If true then transpose the first input before dot.")
@@ -52,6 +53,15 @@ struct DotParam : public dmlc::Parameter<DotParam> {
     DMLC_DECLARE_FIELD(transpose_b)
       .describe("If true then transpose the second input before dot.")
       .set_default(false);
+    DMLC_DECLARE_FIELD(forward_stype)
+      .describe("The desired storage type of the forward output given by user. "
+                "If the combination of input storage types and this hint does not match "
+                "any implemented ones, the dot operator will perform fallback operation "
+                "and still produce an output of the desired storage type.")
+      .add_enum("default", kDefaultStorage)
+      .add_enum("row_sparse", kRowSparseStorage)
+      .add_enum("csr", kCSRStorage)
+      .set_default(dmlc::optional<int>());
   }
 };
@@ -217,35 +227,57 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   bool only_lhs_transpose = param.transpose_a && !param.transpose_b;
   bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
+  bool hint_has_value = param.forward_stype.has_value();
+  NDArrayStorageType target_stype = hint_has_value ?
+                                    static_cast<NDArrayStorageType>(param.forward_stype.value()) :
+                                    kUndefinedStorage;
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
     // dns, dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     DispatchMode::kFCompute);
+    target_stype = hint_has_value ? target_stype : kDefaultStorage;
+    if (target_stype == kDefaultStorage) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       DispatchMode::kFCompute);
+    }
   }
-  if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose &&
-      (rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) {
+  if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
     // csr.T, rsp/dns -> rsp
-    dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+    target_stype = hint_has_value ? target_stype : kRowSparseStorage;
+    if (target_stype == kRowSparseStorage) {
+      dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+    }
   }
   if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
       !param.transpose_a && !param.transpose_b) {
     // csr, rsp/dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     DispatchMode::kFComputeEx);
+    target_stype = hint_has_value ? target_stype : kDefaultStorage;
+    if (target_stype == kDefaultStorage) {
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       DispatchMode::kFComputeEx);
+    }
   }
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
-      !param.transpose_a && !param.transpose_b) {
-    // dns, csr -> csr
-    const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
-    const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
-                                         : DispatchMode::kFComputeEx;
-    dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
-                                     dispatch_ex);
+      !param.transpose_a) {
+    target_stype = hint_has_value ?
target_stype : kCSRStorage; + // dns, csr -> csr on CPU + if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) { + if (target_stype == kCSRStorage) { + dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } + // dns, csr/csr.T -> dns on GPU + } else if (dev_mask == mshadow::gpu::kDevMask) { + if (target_stype == kDefaultStorage) { + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } + } } if (!dispatched) { - dispatched = dispatch_fallback(out_attrs, dispatch_mode); + target_stype = (target_stype == kUndefinedStorage)? kDefaultStorage : target_stype; + dispatched = storage_type_assign(&out_stype, target_stype, dispatch_mode, + DispatchMode::kFComputeFallback); } return dispatched; } @@ -291,6 +323,15 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, dispatched = true; } } + if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a && + lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && + ograd_stype == kDefaultStorage) { + if (type_assign(&lhs_grad_stype, kDefaultStorage) && + type_assign(&rhs_grad_stype, kDefaultStorage)) { + DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); + dispatched = true; + } + } if (!dispatched) { dispatched = dispatch_fallback(out_attrs, dispatch_mode); } @@ -897,10 +938,9 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, } /* - * \brief CPU Impl of dot(dns, csr) = csr + * \brief Impl of dot(dns, csr) = csr */ -template -inline void DotDnsCsrCsrImpl(const OpContext& ctx, +inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, const TBlob& lhs, const NDArray& rhs, const OpReqType req, NDArray* ret) { if (kNullOp == req) return; @@ -986,6 +1026,16 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, }); } +/* + * \brief Impl of dot(dns, csr) = dense (GPU only) + */ +inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev, + const TBlob& dns, const NDArray& rhs, + const OpReqType req, NDArray* ret, + const bool transpose_b) { + LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU"; +} + inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -1039,7 +1089,6 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "transposing rhs of the sparse dot op is not supported"; CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs"; CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs"; auto lhs_stype = inputs[0].storage_type(); @@ -1065,7 +1114,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, out_stype == kCSRStorage && !(param.transpose_a || param.transpose_b)) { NDArray ret = outputs[0]; - DotDnsCsrCsrImpl(ctx, inputs[0].data(), inputs[1], req[0], &ret); + DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret); + } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && + out_stype == kDefaultStorage && !(param.transpose_a)) { + NDArray ret = outputs[0]; + DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret, param.transpose_b); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } @@ -1080,16 +1133,18 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 2U); CHECK_EQ(req.size(), 2U); 
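  // Note (annotation, not part of the original patch): outputs[0]/req[0] correspond to the
  // gradient w.r.t. the lhs input and outputs[1]/req[1] to the gradient w.r.t. the rhs input;
  // the checks below reject requests for gradients that would have to be written in csr storage.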
-  CHECK_EQ(kNullOp, req[0])
-    << "sparse dot does not support computing the gradient of the csr/lhs";
+  CHECK(!(req[0] != kNullOp && outputs[0].storage_type() == kCSRStorage))
+      << "sparse dot does not support computing the gradient of csr";
+  CHECK(!(req[1] != kNullOp && outputs[1].storage_type() == kCSRStorage))
+      << "sparse dot does not support computing the gradient of csr";
   CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace";
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)";
   CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
   CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
   const auto ograd_stype = inputs[0].storage_type();
   const auto lhs_stype = inputs[1].storage_type();
+  const auto rhs_stype = inputs[2].storage_type();
   const auto grad_rhs_stype = outputs[1].storage_type();
   if (ograd_stype == kDefaultStorage  // ograd dns format
       && lhs_stype == kCSRStorage  // csr input lhs of the op
@@ -1108,6 +1163,11 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
       && grad_rhs_stype == kDefaultStorage && !param.transpose_b) {
     TBlob ret = outputs[1].data();
     DotCsrRspDnsImpl(ctx, xpu(), inputs[1], inputs[0], req[1], !param.transpose_a, &ret);
+  } else if (ograd_stype == kDefaultStorage &&  // ograd dns format
+             lhs_stype == kDefaultStorage &&  // lhs dns format
+             rhs_stype == kCSRStorage && !param.transpose_a) {
+    NDArray ret = outputs[0];
+    DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[2], req[0], &ret, !param.transpose_b);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 834b559b86f6..9d62d0daa391 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -51,13 +51,19 @@ NNVM_REGISTER_OP(dot)
   dot(x,y)[0,0,1,1] = 0
   sum(x[0,0,:]*y[:,1,1]) = 0
-The storage type of ``dot`` output depends on storage types of inputs and transpose options:
+The storage type of ``dot`` output depends on the storage types of the inputs, the transpose
+options, and the ``forward_stype`` hint for the output storage type. Implemented sparse
+operations include:
 - dot(csr, default) = default
-- dot(csr.T, default) = row_sparse
+- dot(csr, default, transpose_a=True) = row_sparse
 - dot(csr, row_sparse) = default
-- dot(default, csr) = csr
-- otherwise, ``dot`` generates output with default storage
+- dot(default, csr) = csr (CPU only)
+- dot(default, csr, forward_stype='default') = default (GPU only)
+- dot(default, csr, transpose_b=True, forward_stype='default') = default (GPU only)
+
+If the combination of input storage types and ``forward_stype`` does not match any of the
+patterns above, ``dot`` falls back and generates output with default storage.
 )doc" ADD_FILELINE)
 .set_num_inputs(2)
diff --git a/src/operator/tensor/util/tensor_util-inl.cuh b/src/operator/tensor/util/tensor_util-inl.cuh
index f38e8e117c94..c9ee625af0c8 100644
--- a/src/operator/tensor/util/tensor_util-inl.cuh
+++ b/src/operator/tensor/util/tensor_util-inl.cuh
@@ -231,6 +231,22 @@ struct MarkCsrColWarpKernel {
   }
 };
+/*!
+ * \brief GPU Kernel to perform histogram (input types should be integer types) + * Parallelization by output elements: 1 thread/input element + */ +struct HistogramKernel { + template + __device__ __forceinline__ static void Map(int tid, + IType* target, + const CType* source, + const nnvm::dim_t num_elems) { + if (tid < num_elems) { + atomicAdd(&target[source[tid]], 1); + } + } +}; + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 0a5bda831d9c..abb27de1dc71 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -738,8 +738,8 @@ def test_lambda(): input_data = mx.nd.random.uniform(shape=(2, 3, 5, 7)) out1, out2, out3 = net1(input_data), net2(input_data), net3(input_data) - assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3) - assert_almost_equal(out1.asnumpy(), out3.asnumpy(), rtol=1e-3) + assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3, atol=1e-3) + assert_almost_equal(out1.asnumpy(), out3.asnumpy(), rtol=1e-3, atol=1e-3) @with_seed() diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 484c98643d91..16b52f60ceb9 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1208,6 +1208,27 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad= @with_seed() def test_sparse_dot(): + def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_a, trans_b): + all_stypes = ["default", "csr", "row_sparse"] + lhs_nd = rand_ndarray(lhs_shape, 'default', density=lhs_density) + rhs_nd = rand_ndarray(rhs_shape, 'default', density=rhs_density) + out_nd = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_a, transpose_b=trans_b) + out_np = out_nd.asnumpy() + for lhs_stype in all_stypes: + for rhs_stype in all_stypes: + for forward_stype in all_stypes: + lhs = lhs_nd.tostype(lhs_stype) + rhs = rhs_nd.tostype(rhs_stype) + out = mx.nd.dot(lhs, rhs, forward_stype=forward_stype, + transpose_a=trans_a, transpose_b=trans_b) + assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5) + lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype) + rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype) + out = mx.symbol.sparse.dot(lhs_var, rhs_var, + forward_stype=forward_stype, + transpose_a=trans_a, transpose_b=trans_b) + location = {'lhs': lhs, 'rhs': rhs} + check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_density): lhs_nd = rand_ndarray(lhs_shape, 'csr', density=lhs_density, shuffle_csr_indices=False) lhs_dns = lhs_nd.tostype('default') @@ -1239,25 +1260,39 @@ def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=F rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density) rhs_dns = rhs_nd.tostype('default') - out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs) - out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs) + if default_context() == mx.cpu(): + forward_stype = 'csr' + else: + forward_stype = 'default' + out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype) + out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype) out_np = out_dns.asnumpy() assert_almost_equal(out.asnumpy(), 
out_np, rtol=1e-4, atol=1e-5) # test symbolic forward lhs = mx.symbol.Variable('lhs', stype='default') rhs = mx.symbol.Variable('rhs', stype='csr') - out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs) + out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype) location = {'lhs': lhs_nd, 'rhs': rhs_nd} check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) - # test symbolic backward - backward_trans = not trans_lhs - rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() - expected = {'rhs': rhs_backward_grad} - check_symbolic_backward(out, location, [out_np], expected, - grad_req={'lhs': 'null', 'rhs': 'write'}, - rtol=1e-3, atol=1e-4) + if default_context() == mx.cpu(): + # test symbolic backward + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() + if trans_rhs is True: + rhs_backward_grad = rhs_backward_grad.T + expected = {'rhs': rhs_backward_grad} + check_symbolic_backward(out, location, [out_np], expected, + grad_req={'lhs': 'null', 'rhs': 'write'}, + rtol=1e-3, atol=1e-4) + else: + transpose_b = not trans_rhs + lhs_backward_grad = mx.nd.dot(out_dns, rhs_dns, transpose_b=transpose_b) + expected = {'lhs': lhs_backward_grad.asnumpy()} + check_symbolic_backward(out, location, [out_np], expected, + grad_req={'lhs': 'write', 'rhs': 'null'}, + rtol=1e-3, atol=1e-4) def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): """Test for nnr_out = 0. Before the fix, the test would fail.""" @@ -1276,7 +1311,7 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): sps_out = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'), transpose_a=trans_lhs) assert same(dns_out.asnumpy(), sps_out.asnumpy()) - density = [1.00, 0.50, 0.01] + density = [1.00, 0.5, 0.01] for lhs_d in density: lhs_shape = rand_shape_2d(50, 200) rhs_d = 1 @@ -1285,15 +1320,52 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel) test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(50, 200)), lhs_d, lhs_d) + test_dot_dns_csr(lhs_shape, (rnd.randint(50, 200), lhs_shape[1]), lhs_d, lhs_d, trans_rhs=True) for rhs_d in density: test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d) test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d) - + test_infer_forward_stype(lhs_shape, (lhs_shape[1], rnd.randint(10, 20)), + lhs_d, rhs_d, False, False) + test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[1]), + lhs_d, rhs_d, False, True) + test_infer_forward_stype(lhs_shape, (lhs_shape[0], rnd.randint(10, 20)), + lhs_d, rhs_d, True, False) + test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[0]), + lhs_d, rhs_d, True, True) test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40) test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40) +@with_seed() +def test_sparse_dot_determinism(): + def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b): + lhs_row = rnd.randint(50, 100) + lhs_col = rnd.randint(50, 100) + if transpose_a: + if transpose_b: + rhs_shape = (rnd.randint(50, 100), lhs_row) + else: + rhs_shape = 
(lhs_row, rnd.randint(50, 100)) + else: + if transpose_b: + rhs_shape = (rnd.randint(50, 100), lhs_col) + else: + rhs_shape = (lhs_col, rnd.randint(50, 100)) + if default_context() == mx.cpu(): + forward_stype = 'csr' + else: + forward_stype = 'default' + lhs_shape = (lhs_row, lhs_col) + lhs = rand_ndarray(lhs_shape, lhs_stype, density=lhs_density) + rhs = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density) + res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype) + res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype) + assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0) + test_dot_determinism('default', 'csr', 1.0, 0.1, False, False) + test_dot_determinism('default', 'csr', 1.0, 0.1, False, True) + + @with_seed() def test_sparse_slice(): def check_csr_slice(shape, slice_input):
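
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch): a minimal example of the
# forward_stype hint introduced above, assuming an MXNet build with GPU support.
# The shapes and the use of uniform random data converted to csr are arbitrary
# choices for illustration only.
import mxnet as mx

lhs = mx.nd.random.uniform(shape=(64, 128), ctx=mx.gpu())                   # dense lhs
rhs = mx.nd.random.uniform(shape=(128, 256), ctx=mx.gpu()).tostype('csr')   # csr rhs
# dot(dns, csr) = dns on GPU, requested explicitly through the forward_stype hint
out = mx.nd.sparse.dot(lhs, rhs, forward_stype='default')

rhs_t = mx.nd.random.uniform(shape=(256, 128), ctx=mx.gpu()).tostype('csr')
# dot(dns, csr.T) = dns on GPU
out_t = mx.nd.sparse.dot(lhs, rhs_t, transpose_b=True, forward_stype='default')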